From 369dc3d645124a79feaf5aa241987a4c46381b1c Mon Sep 17 00:00:00 2001 From: Tejas Shah <shatejas@amazon.com> Date: Tue, 13 Aug 2024 16:45:44 -0700 Subject: [PATCH] Iterative index integration (#1956) * Iterative Vector Insertion (#1840) * Rebased with new version of k-NN Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Optimized faiss insertion Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Optimized threadCount logic Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Removed IDEA files Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Removed unnecessary cmake file Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Added comments to new functions Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Removed createIndex and fixed test cases that use it Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Removed unused code Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Explained zero initialization for vector transfer Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Added locale Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Spotless Apply Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Account for zero documents in finished batch Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Changed where we check for zero docs Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Changed tip for return Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Use unique pointers to make sure resources are released on exception Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Moved createIndex to testUtils Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Fixed memory management so that the underlying index is not deleted after initialized Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Created new KNNIndexBuilder graph to make index building more modular Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Streamlined logic in KNNIndexBuilder. 
Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Cleaned up unnecessary code in KNN80DocValuesConsumer Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Fixed memory management process Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Added note about index initialization in faiss_index_service Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Accounted for case where the exception happens after the indexWriter is released. Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Delete jni/src/.idea/modules.xml Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Delete jni/src/.idea/vcs.xml Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Delete jni/src/.idea/workspace.xml Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Spotless apply and free iterative index on exception Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Undid hack for checking first document metrics Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Removed print statements Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Free Vector Transfer on batch ingestion Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Undid free Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Fixed check for transfer ready Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Don't crash when zero vectors inserted? Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Reverted to old insertion process? 
Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Spotless apply Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Added back createOutput Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Removed prior createOutput Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Test remaking vectorTransfer Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Test restructuring of insertion Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Fixed case where vector address is immediately discarded Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Spotless apply Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Split Index Builder into multiple classes Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Fixed descriptions of functions in faiss_index_service Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Added back copyright files Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Removed unused builder names Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Modified tests to work with new insertion methods Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Track index insertions Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Tracked insertions for binary indices Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Added back insertIds Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Added check for freeVectorData to see if it works with an already deleted address Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Cleaned up logs and comments in KNNIndexBuilder Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Restructured the logic for KNNIndexBuilder Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Changed package name of KNNIndexBuilder Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Changed all package names and deleted unnecessary headers Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Fixed for loop Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> 
* Removed createIndex methods for faiss index service Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Fixed package to fit naming conventions Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Changed name of index builder Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Spotless apply Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Added comments to NativeIndexBuilder and restructured Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Added deletion for memoryAddress Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Spotless apply Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Changed naming of classes to Writer and changed package name to fit conventions Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Changed NativeIndexInfo and NativeVectorInfo to follow builder pattern Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Added feature to changelog Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Added class descriptions to each NativeIndexWriter Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Changed name to getBytesPerVector Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Added == false instead of ! 
for readability Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Fixed changelog Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Fixed naming in docvaluesconsumer Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * SpotlessApply Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Made it so that we don't reuse testValues and removed a foot gun Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Removed another foot gun in getIndexInfo Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Fixed javadoc Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Added deletion on exception cases Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Removed unnecessary delete (NativeIndexWriter will handle deletion of vectors on exception) Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Added correct logger and getWriter method to NativeIndexWriter Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Ensured memory safety on JNI layer so that Java doesn't have to wrap everything in a try catch loop. Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Refactored NativeIndexWriter and added comments to FaissService Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Removed free in the JNIExport since index will always be freed in writeIndex. Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Changed getVectorTransfer back to accept VectorDataType Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Reverted free since not guaranteed to be IDMap. 
Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Added all processes in addKNNBinaryField to NativeIndexWriter.createKNNIndex Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Fixed javadoc Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Applied spotless Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Added back writeFooter Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Removed threadCount fron writeIndex Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Removed redundancies in KNN80DocValuesConsumer Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Removed serializationMode Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Fixed changelog Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Fixed changelog Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Removed double free test as we don't have to worry about that anymore Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Accounted for HNSWSQ in index service Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Removed delete in catch Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Fixed faiss tests to work with writeIndex Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> --------- Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Index Initialization Alloc Method (#1933) * Added methods for allocating memory before inserting vectors to a faiss index Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Fixed logic that gets type of index Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Removed print statement Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Removed unnecessary iostream Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Removed flat index Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Fixed flat index case Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Fixed naming Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Properly allocate HNSWSQ storage Signed-off-by: 
Andrew Klepchick <aklepchi@amazon.com> * Removed print statements Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Fixed changelog Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Removed unnecessary lib Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Made alloc adaptive to different code sizes Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> --------- Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> * Integrates FAISS iterative builds with NativeEngines990KnnVectorsFormat Changes include reusing the same vector buffer in the JNI layer Signed-off-by: Tejas Shah <shatejas@amazon.com> --------- Signed-off-by: Andrew Klepchick <aklepchi@amazon.com> Signed-off-by: Tejas Shah <shatejas@amazon.com> Co-authored-by: Andrew Klepchick <aklepchi@amazon.com> --- CHANGELOG.md | 7 +- jni/include/commons.h | 17 +- jni/include/faiss_index_service.h | 87 +++--- jni/include/faiss_wrapper.h | 9 +- .../org_opensearch_knn_jni_FaissService.h | 49 ++- .../org_opensearch_knn_jni_JNICommons.h | 6 +- jni/src/commons.cpp | 14 +- jni/src/faiss_index_service.cpp | 170 ++++++++--- jni/src/faiss_wrapper.cpp | 72 +++-- .../org_opensearch_knn_jni_FaissService.cpp | 77 ++++- jni/src/org_opensearch_knn_jni_JNICommons.cpp | 8 +- jni/tests/commons_test.cpp | 116 +++++++- jni/tests/faiss_index_service_test.cpp | 30 +- jni/tests/faiss_wrapper_test.cpp | 94 +++++- jni/tests/mocks/faiss_index_service_mock.h | 25 +- .../knn/common/FieldInfoExtractor.java | 18 ++ .../opensearch/knn/common/KNNVectorUtil.java | 18 +- .../codec/BasePerFieldKnnVectorsFormat.java | 69 +++-- .../KNN80Codec/KNN80DocValuesConsumer.java | 265 +---------------- .../NativeEngineFieldVectorsWriter.java | 3 + .../NativeEngines990KnnVectorsWriter.java | 44 ++- .../MemOptimizedNativeIndexBuildStrategy.java | 89 ++++++ .../nativeindex/NativeIndexBuildStrategy.java | 19 ++ .../codec/nativeindex/NativeIndexWriter.java | 280 ++++++++++++++++++ .../VectorTransferIndexBuildStrategy.java | 91 ++++++ 
.../nativeindex/model/BuildIndexParams.java | 23 ++ .../OffHeapBinaryQuantizedVectorTransfer.java | 40 +++ .../OffHeapByteQuantizedVectorTransfer.java | 40 +++ .../transfer/OffHeapFloatVectorTransfer.java | 43 +++ .../OffHeapQuantizedVectorTransfer.java | 124 ++++++++ .../index/codec/transfer/VectorTransfer.java | 49 ++- .../codec/transfer/VectorTransferByte.java | 68 ----- .../codec/transfer/VectorTransferFloat.java | 71 ----- .../knn/index/codec/util/KNNCodecUtil.java | 49 --- .../vectorvalues/KNNBinaryVectorValues.java | 17 +- .../vectorvalues/KNNByteVectorValues.java | 12 + .../vectorvalues/KNNFloatVectorValues.java | 12 + .../index/vectorvalues/KNNVectorValues.java | 36 ++- .../org/opensearch/knn/indices/ModelUtil.java | 7 +- .../org/opensearch/knn/jni/FaissService.java | 64 +++- .../org/opensearch/knn/jni/JNICommons.java | 76 ++++- .../org/opensearch/knn/jni/JNIService.java | 197 ++++++++---- .../KNN80DocValuesConsumerTests.java | 82 +++-- ...NativeEngines990KnnVectorsFormatTests.java | 14 +- .../transfer/OffHeapVectorTransferTests.java | 134 +++++++++ .../transfer/VectorTransferByteTests.java | 56 ---- .../transfer/VectorTransferFloatTests.java | 66 ----- .../index/codec/util/KNNCodecUtilTests.java | 43 --- .../memory/NativeMemoryAllocationTests.java | 5 +- .../memory/NativeMemoryLoadStrategyTests.java | 7 +- .../opensearch/knn/jni/JNIServiceTests.java | 93 +++--- .../java/org/opensearch/knn/TestUtils.java | 17 +- 52 files changed, 2091 insertions(+), 1031 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategy.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/nativeindex/VectorTransferIndexBuildStrategy.java create mode 100644 
src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryQuantizedVectorTransfer.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteQuantizedVectorTransfer.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapQuantizedVectorTransfer.java delete mode 100644 src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java delete mode 100644 src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java delete mode 100644 src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferByteTests.java delete mode 100644 src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloatTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index a5c641b8f..555983d01 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,11 +17,14 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features * Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation [#1945](https://github.com/opensearch-project/k-NN/pull/1945) ### Enhancements +* Add functionality to iteratively insert vectors into a faiss index to improve the memory footprint during indexing. 
[#1840](https://github.com/opensearch-project/k-NN/pull/1840) ### Bug Fixes * Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874) * Add script_fields context to KNNAllowlist [#1917] (https://github.com/opensearch-project/k-NN/pull/1917) -* Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844) -* Disallow a vector field to have an invalid character for a physical file name. [#1936](https://github.com/opensearch-project/k-NN/pull/1936) +* Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844) +* Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation [#1945](https://github.com/opensearch-project/k-NN/pull/1945) +* Disallow a vector field to have an invalid character for a physical file name. [#1936](https://github.com/opensearch-project/k-NN/pull/1936) +* Fixed and abstracted functionality for allocating index memory [#1933](https://github.com/opensearch-project/k-NN/pull/1933) ### Infrastructure ### Documentation ### Maintenance diff --git a/jni/include/commons.h b/jni/include/commons.h index d02439377..4cdaf28fc 100644 --- a/jni/include/commons.h +++ b/jni/include/commons.h @@ -19,12 +19,19 @@ namespace knn_jni { * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location * will throw Exception. * + * append tells the method to keep appending to the existing vector. Passing the value as false will clear the vector + * without reallocating new memory. This helps with reducing memory fragmentation and overhead of allocating + * and deallocating when the memory address needs to be reused. + * + * CAUTION: The behavior is undefined if the memory address is deallocated and the method is called + * * @param memoryAddress The address of the memory location where data will be stored. 
* @param data 2D float array containing data to be stored in native memory. * @param initialCapacity The initial capacity of the memory location. + * @param append whether to append or start from index 0 when called subsequently with the same address * @return memory address of std::vector<float> where the data is stored. */ - jlong storeVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong); + jlong storeVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean); /** * This is utility function that can be used to store data in native memory. This function will allocate memory for @@ -33,12 +40,18 @@ namespace knn_jni { * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location * will throw Exception. * + * append tells the method to keep appending to the existing vector. Passing the value as false will clear the vector + * without reallocating new memory. This helps with reducing memory fragmentation and overhead of allocating + * and deallocating when the memory address needs to be reused. + * + * CAUTION: The behavior is undefined if the memory address is deallocated and the method is called + * * @param memoryAddress The address of the memory location where data will be stored. * @param data 2D byte array containing data to be stored in native memory. * @param initialCapacity The initial capacity of the memory location. * @return memory address of std::vector<uint8_t> where the data is stored. */ - jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong); + jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean); /** * Free up the memory allocated for the data stored in memory address. 
This function should be used with the memory diff --git a/jni/include/faiss_index_service.h b/jni/include/faiss_index_service.h index 59f15fda9..c57309cfc 100644 --- a/jni/include/faiss_index_service.h +++ b/jni/include/faiss_index_service.h @@ -31,38 +31,41 @@ namespace faiss_wrapper { class IndexService { public: IndexService(std::unique_ptr<FaissMethods> faissMethods); - //TODO Remove dependency on JNIUtilInterface and JNIEnv - //TODO Reduce the number of parameters - /** - * Create index + * Initialize index * * @param jniUtil jni util * @param env jni environment * @param metric space type for distance calculation * @param indexDescription index description to be used by faiss index factory * @param dim dimension of vectors + * @param numVectors number of vectors + * @param threadCount number of thread count to be used while adding data + * @param parameters parameters to be applied to faiss index + * @return memory address of the native index object + */ + virtual jlong initIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numVectors, int threadCount, std::unordered_map<std::string, jobject> parameters); + /** + * Add vectors to index + * + * @param dim dimension of vectors * @param numIds number of vectors * @param threadCount number of thread count to be used while adding data * @param vectorsAddress memory address which is holding vector data - * @param ids a list of document ids for corresponding vectors + * @param idMapAddress memory address of the native index object + */ + virtual void insertToIndex(int dim, int numIds, int threadCount, int64_t vectorsAddress, std::vector<int64_t> &ids, jlong idMapAddress); + /** + * Write index to disk + * + * @param threadCount number of thread count to be used while adding data * @param indexPath path to write index - * @param parameters parameters to be applied to faiss index + * @param idMap memory address of the native index object */ - 
virtual void createIndex( - knn_jni::JNIUtilInterface * jniUtil, - JNIEnv * env, - faiss::MetricType metric, - std::string indexDescription, - int dim, - int numIds, - int threadCount, - int64_t vectorsAddress, - std::vector<int64_t> ids, - std::string indexPath, - std::unordered_map<std::string, jobject> parameters); + virtual void writeIndex(std::string indexPath, jlong idMapAddress); virtual ~IndexService() = default; protected: + virtual void allocIndex(faiss::Index * index, size_t dim, size_t numVectors); std::unique_ptr<FaissMethods> faissMethods; }; @@ -76,7 +79,21 @@ class BinaryIndexService : public IndexService { //TODO Reduce the number of parameters BinaryIndexService(std::unique_ptr<FaissMethods> faissMethods); /** - * Create binary index + * Initialize index + * + * @param jniUtil jni util + * @param env jni environment + * @param metric space type for distance calculation + * @param indexDescription index description to be used by faiss index factory + * @param dim dimension of vectors + * @param numVectors number of vectors + * @param threadCount number of thread count to be used while adding data + * @param parameters parameters to be applied to faiss index + * @return memory address of the native index object + */ + virtual jlong initIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numVectors, int threadCount, std::unordered_map<std::string, jobject> parameters) override; + /** + * Add vectors to index * * @param jniUtil jni util * @param env jni environment @@ -86,28 +103,30 @@ class BinaryIndexService : public IndexService { * @param numIds number of vectors * @param threadCount number of thread count to be used while adding data * @param vectorsAddress memory address which is holding vector data - * @param ids a list of document ids for corresponding vectors + * @param idMap a map of document id and vector id + * @param parameters parameters to be applied to faiss index 
+ */ + virtual void insertToIndex(int dim, int numIds, int threadCount, int64_t vectorsAddress, std::vector<int64_t> &ids, jlong idMapAddress) override; + /** + * Write index to disk + * + * @param jniUtil jni util + * @param env jni environment + * @param metric space type for distance calculation + * @param indexDescription index description to be used by faiss index factory + * @param threadCount number of thread count to be used while adding data * @param indexPath path to write index + * @param idMap a map of document id and vector id * @param parameters parameters to be applied to faiss index */ - virtual void createIndex( - knn_jni::JNIUtilInterface * jniUtil, - JNIEnv * env, - faiss::MetricType metric, - std::string indexDescription, - int dim, - int numIds, - int threadCount, - int64_t vectorsAddress, - std::vector<int64_t> ids, - std::string indexPath, - std::unordered_map<std::string, jobject> parameters - ) override; + virtual void writeIndex(std::string indexPath, jlong idMapAddress) override; virtual ~BinaryIndexService() = default; +protected: + virtual void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) override; }; } } -#endif //OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H +#endif //OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H \ No newline at end of file diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 5ad0dedc4..574efb6fd 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -18,10 +18,11 @@ namespace knn_jni { namespace faiss_wrapper { - // Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ. - // The index is serialized to indexPathJ. 
- void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, - jstring indexPathJ, jobject parametersJ, IndexService* indexService); + jlong InitIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong numDocs, jint dimJ, jobject parametersJ, IndexService *indexService); + + void InsertToIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jlong indexAddr, jint threadCount, IndexService *indexService); + + void WriteIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jstring indexPathJ, jlong indexAddr, IndexService *indexService); // Create an index with ids and vectors. Instead of creating a new index, this function creates the index // based off of the template index passed in. The index is serialized to indexPathJ. diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 025fb12e8..19e13d402 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -18,23 +18,54 @@ #ifdef __cplusplus extern "C" { #endif - /* * Class: org_opensearch_knn_jni_FaissService - * Method: createIndex + * Method: initIndex * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V */ -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex - (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject); - +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndex(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ); /* * Class: org_opensearch_knn_jni_FaissService - * Method: createBinaryIndex + * Method: initBinaryIndex * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V */ -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndex - (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject); - +JNIEXPORT jlong JNICALL 
Java_org_opensearch_knn_jni_FaissService_initBinaryIndex(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: insertToIndex + * Signature: ([IJIJI)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JNIEnv * env, jclass cls, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, + jlong indexAddress, jint threadCount); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: insertToBinaryIndex + * Signature: ([IJIJI)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, + jlong indexAddress, jint threadCount); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: writeIndex + * Signature: (JLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEnv * env, jclass cls, + jlong indexAddress, + jstring indexPathJ); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: writeBinaryIndex + * Signature: (JLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex(JNIEnv * env, jclass cls, + jlong indexAddress, + jstring indexPathJ); /* * Class: org_opensearch_knn_jni_FaissService * Method: createIndexFromTemplate diff --git a/jni/include/org_opensearch_knn_jni_JNICommons.h b/jni/include/org_opensearch_knn_jni_JNICommons.h index 89de76520..03c0d023a 100644 --- a/jni/include/org_opensearch_knn_jni_JNICommons.h +++ b/jni/include/org_opensearch_knn_jni_JNICommons.h @@ -21,10 +21,10 @@ extern "C" { /* * Class: org_opensearch_knn_jni_JNICommons * Method: storeVectorData - * Signature: (J[[FJJ) + * Signature: (J[[FJZ)J */ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData - (JNIEnv *, jclass, jlong, 
jobjectArray, jlong); + (JNIEnv *, jclass, jlong, jobjectArray, jlong, jboolean); /* * Class: org_opensearch_knn_jni_JNICommons @@ -32,7 +32,7 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData * Signature: (J[[FJJ) */ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData - (JNIEnv *, jclass, jlong, jobjectArray, jlong); + (JNIEnv *, jclass, jlong, jobjectArray, jlong, jboolean); /* * Class: org_opensearch_knn_jni_JNICommons diff --git a/jni/src/commons.cpp b/jni/src/commons.cpp index 13f59194e..f9764db73 100644 --- a/jni/src/commons.cpp +++ b/jni/src/commons.cpp @@ -18,7 +18,7 @@ #include "commons.h" jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ, - jobjectArray dataJ, jlong initialCapacityJ) { + jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) { std::vector<float> *vect; if ((long) memoryAddressJ == 0) { vect = new std::vector<float>(); @@ -26,6 +26,11 @@ jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIE } else { vect = reinterpret_cast<std::vector<float>*>(memoryAddressJ); } + + if (appendJ == JNI_FALSE) { + vect->clear(); + } + int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, dataJ); jniUtil->Convert2dJavaObjectArrayAndStoreToFloatVector(env, dataJ, dim, vect); @@ -33,7 +38,7 @@ jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIE } jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ, - jobjectArray dataJ, jlong initialCapacityJ) { + jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) { std::vector<uint8_t> *vect; if ((long) memoryAddressJ == 0) { vect = new std::vector<uint8_t>(); @@ -41,6 +46,11 @@ jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil, } else { vect = reinterpret_cast<std::vector<uint8_t>*>(memoryAddressJ); } + + if (appendJ == JNI_FALSE) { + 
vect->clear(); + } + int dim = jniUtil->GetInnerDimensionOf2dJavaByteArray(env, dataJ); jniUtil->Convert2dJavaObjectArrayAndStoreToByteVector(env, dataJ, dim, vect); diff --git a/jni/src/faiss_index_service.cpp b/jni/src/faiss_index_service.cpp index 8c5ba36af..69866da76 100644 --- a/jni/src/faiss_index_service.cpp +++ b/jni/src/faiss_index_service.cpp @@ -57,76 +57,157 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, IndexService::IndexService(std::unique_ptr<FaissMethods> faissMethods) : faissMethods(std::move(faissMethods)) {} -void IndexService::createIndex( +void IndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVectors) { + if(auto * indexHNSWSQ = dynamic_cast<faiss::IndexHNSWSQ *>(index)) { + if(auto * indexScalarQuantizer = dynamic_cast<faiss::IndexScalarQuantizer *>(indexHNSWSQ->storage)) { + indexScalarQuantizer->codes.reserve(indexScalarQuantizer->code_size * numVectors); + } + return; + } + if(auto * indexHNSW = dynamic_cast<faiss::IndexHNSW *>(index)) { + if(auto * indexFlat = dynamic_cast<faiss::IndexFlat *>(indexHNSW->storage)) { + indexFlat->codes.reserve(indexFlat->code_size * numVectors); + } + return; + } +} + +jlong IndexService::initIndex( knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, faiss::MetricType metric, std::string indexDescription, + int dim, + int numVectors, + int threadCount, + std::unordered_map<std::string, jobject> parameters + ) { + // Create index using Faiss factory method + std::unique_ptr<faiss::Index> indexWriter(faissMethods->indexFactory(dim, indexDescription.c_str(), metric)); + + // Set thread count if it is passed in as a parameter. 
Setting this variable will only impact the current thread + if(threadCount != 0) { + omp_set_num_threads(threadCount); + } + + // Add extra parameters that cant be configured with the index factory + SetExtraParameters<faiss::Index, faiss::IndexIVF, faiss::IndexHNSW>(jniUtil, env, parameters, indexWriter.get()); + + // Check that the index does not need to be trained + if(!indexWriter->is_trained) { + throw std::runtime_error("Index is not trained"); + } + + std::unique_ptr<faiss::IndexIDMap> idMap (faissMethods->indexIdMap(indexWriter.get())); + + allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors); + indexWriter.release(); + return reinterpret_cast<jlong>(idMap.release()); +} + +void IndexService::insertToIndex( int dim, int numIds, int threadCount, int64_t vectorsAddress, - std::vector<int64_t> ids, - std::string indexPath, - std::unordered_map<std::string, jobject> parameters + std::vector<int64_t> & ids, + jlong idMapAddress ) { - // Read vectors from memory address - auto *inputVectors = reinterpret_cast<std::vector<float>*>(vectorsAddress); + // Read vectors from memory address (unique ptr since we want to remove from memory after use) + std::vector<float> * inputVectors = reinterpret_cast<std::vector<float>*>(vectorsAddress); // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value int numVectors = (int) (inputVectors->size() / (uint64_t) dim); if(numVectors == 0) { - throw std::runtime_error("Number of vectors cannot be 0"); + return; } if (numIds != numVectors) { throw std::runtime_error("Number of IDs does not match number of vectors"); } - std::unique_ptr<faiss::Index> indexWriter(faissMethods->indexFactory(dim, indexDescription.c_str(), metric)); - // Set thread count if it is passed in as a parameter. 
Setting this variable will only impact the current thread if(threadCount != 0) { omp_set_num_threads(threadCount); } - // Add extra parameters that cant be configured with the index factory - SetExtraParameters<faiss::Index, faiss::IndexIVF, faiss::IndexHNSW>(jniUtil, env, parameters, indexWriter.get()); - - // Check that the index does not need to be trained - if(!indexWriter->is_trained) { - throw std::runtime_error("Index is not trained"); - } + faiss::IndexIDMap * idMap = reinterpret_cast<faiss::IndexIDMap *> (idMapAddress); // Add vectors - std::unique_ptr<faiss::IndexIDMap> idMap(faissMethods->indexIdMap(indexWriter.get())); idMap->add_with_ids(numVectors, inputVectors->data(), ids.data()); +} - // Write the index to disk - faissMethods->writeIndex(idMap.get(), indexPath.c_str()); +void IndexService::writeIndex( + std::string indexPath, + jlong idMapAddress + ) { + std::unique_ptr<faiss::IndexIDMap> idMap (reinterpret_cast<faiss::IndexIDMap *> (idMapAddress)); + + try { + // Write the index to disk + faissMethods->writeIndex(idMap.get(), indexPath.c_str()); + } catch(std::exception &e) { + delete idMap->index; + throw std::runtime_error("Failed to write index to disk"); + } + // Free the memory used by the index + delete idMap->index; } BinaryIndexService::BinaryIndexService(std::unique_ptr<FaissMethods> faissMethods) : IndexService(std::move(faissMethods)) {} -void BinaryIndexService::createIndex( +void BinaryIndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVectors) { + if(auto * indexBinaryHNSW = dynamic_cast<faiss::IndexBinaryHNSW *>(index)) { + auto * indexBinaryFlat = dynamic_cast<faiss::IndexBinaryFlat *>(indexBinaryHNSW->storage); + indexBinaryFlat->xb.reserve(dim * numVectors / 8); + return; + } +} + +jlong BinaryIndexService::initIndex( knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, faiss::MetricType metric, std::string indexDescription, int dim, - int numIds, + int numVectors, int threadCount, - int64_t vectorsAddress, - 
std::vector<int64_t> ids, - std::string indexPath, std::unordered_map<std::string, jobject> parameters ) { - // Read vectors from memory address - auto *inputVectors = reinterpret_cast<std::vector<uint8_t>*>(vectorsAddress); + // Create index using Faiss factory method + std::unique_ptr<faiss::IndexBinary> indexWriter(faissMethods->indexBinaryFactory(dim, indexDescription.c_str())); - if (dim % 8 != 0) { - throw std::runtime_error("Dimensions should be multiply of 8"); + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(threadCount != 0) { + omp_set_num_threads(threadCount); } + + // Add extra parameters that cant be configured with the index factory + SetExtraParameters<faiss::IndexBinary, faiss::IndexBinaryIVF, faiss::IndexBinaryHNSW>(jniUtil, env, parameters, indexWriter.get()); + + // Check that the index does not need to be trained + if(!indexWriter->is_trained) { + throw std::runtime_error("Index is not trained"); + } + + std::unique_ptr<faiss::IndexBinaryIDMap> idMap(faissMethods->indexBinaryIdMap(indexWriter.get())); + + allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors); + indexWriter.release(); + return reinterpret_cast<jlong>(idMap.release()); +} + +void BinaryIndexService::insertToIndex( + int dim, + int numIds, + int threadCount, + int64_t vectorsAddress, + std::vector<int64_t> & ids, + jlong idMapAddress + ) { + // Read vectors from memory address (unique ptr since we want to remove from memory after use) + std::vector<uint8_t> * inputVectors = reinterpret_cast<std::vector<uint8_t>*>(vectorsAddress); + // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8)); if(numVectors == 0) { @@ -137,28 +218,35 @@ void BinaryIndexService::createIndex( throw std::runtime_error("Number of IDs does not match number of vectors"); } - 
std::unique_ptr<faiss::IndexBinary> indexWriter(faissMethods->indexBinaryFactory(dim, indexDescription.c_str())); - // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread if(threadCount != 0) { omp_set_num_threads(threadCount); } - // Add extra parameters that cant be configured with the index factory - SetExtraParameters<faiss::IndexBinary, faiss::IndexBinaryIVF, faiss::IndexBinaryHNSW>(jniUtil, env, parameters, indexWriter.get()); - - // Check that the index does not need to be trained - if(!indexWriter->is_trained) { - throw std::runtime_error("Index is not trained"); - } + faiss::IndexBinaryIDMap * idMap = reinterpret_cast<faiss::IndexBinaryIDMap *> (idMapAddress); // Add vectors - std::unique_ptr<faiss::IndexBinaryIDMap> idMap(faissMethods->indexBinaryIdMap(indexWriter.get())); idMap->add_with_ids(numVectors, inputVectors->data(), ids.data()); +} + +void BinaryIndexService::writeIndex( + std::string indexPath, + jlong idMapAddress + ) { + + std::unique_ptr<faiss::IndexBinaryIDMap> idMap (reinterpret_cast<faiss::IndexBinaryIDMap *> (idMapAddress)); + + try { + // Write the index to disk + faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str()); + } catch(std::exception &e) { + delete idMap->index; + throw std::runtime_error("Failed to write index to disk"); + } - // Write the index to disk - faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str()); + // Free the memory used by the index + delete idMap->index; } } // namespace faiss_wrapper -} // namesapce knn_jni +} // namesapce knn_jni \ No newline at end of file diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 1d4437414..0e1029ecf 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -88,24 +88,13 @@ bool isIndexIVFPQL2(faiss::Index * index); // IndexIDMap which has member that will point to underlying index that stores the data faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index); -void 
knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, - jstring indexPathJ, jobject parametersJ, IndexService* indexService) { - if (idsJ == nullptr) { - throw std::runtime_error("IDs cannot be null"); - } - - if (vectorsAddressJ <= 0) { - throw std::runtime_error("VectorsAddress cannot be less than 0"); - } +jlong knn_jni::faiss_wrapper::InitIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong numDocs, jint dimJ, + jobject parametersJ, IndexService* indexService) { if(dimJ <= 0) { throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); } - if (indexPathJ == nullptr) { - throw std::runtime_error("Index path cannot be null"); - } - if (parametersJ == nullptr) { throw std::runtime_error("Parameters cannot be null"); } @@ -124,8 +113,8 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN // Dimension int dim = (int)dimJ; - // Number of vectors - int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); + // Number of docs + int docs = (int)numDocs; // Index description jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); @@ -138,25 +127,60 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); } + // Extra parameters + // TODO: parse the entire map and remove jni object + std::unordered_map<std::string, jobject> subParametersCpp; + if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { + subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersCpp[knn_jni::PARAMETERS]); + } + // end parameters to pass + + // Create index + return indexService->initIndex(jniUtil, env, metric, indexDescriptionCpp, dim, numDocs, threadCount, subParametersCpp); +} + +void knn_jni::faiss_wrapper::InsertToIndex(knn_jni::JNIUtilInterface * 
jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, + jlong index_ptr, jint threadCount, IndexService* indexService) { + if (idsJ == nullptr) { + throw std::runtime_error("IDs cannot be null"); + } + + if (vectorsAddressJ <= 0) { + throw std::runtime_error("VectorsAddress cannot be less than 0"); + } + + if(dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); + } + + // Dimension + int dim = (int)dimJ; + + // Number of vectors + int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); + // Vectors address int64_t vectorsAddress = (int64_t)vectorsAddressJ; // Ids auto ids = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); - // Index path - std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + // Create index + indexService->insertToIndex(dim, numIds, threadCount, vectorsAddress, ids, index_ptr); +} - // Extra parameters - // TODO: parse the entire map and remove jni object - std::unordered_map<std::string, jobject> subParametersCpp; - if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { - subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersCpp[knn_jni::PARAMETERS]); +void knn_jni::faiss_wrapper::WriteIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, + jstring indexPathJ, jlong index_ptr, IndexService* indexService) { + + if (indexPathJ == nullptr) { + throw std::runtime_error("Index path cannot be null"); } - // end parameters to pass + + // Index path + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); // Create index - indexService->createIndex(jniUtil, env, metric, indexDescriptionCpp, dim, numIds, threadCount, vectorsAddress, ids, indexPathCpp, subParametersCpp); + indexService->writeIndex(indexPathCpp, index_ptr); } void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp 
b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 2394e2951..6e7dd4912 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -39,37 +39,84 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) { jniUtil.Uninitialize(env); } -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex(JNIEnv * env, jclass cls, jintArray idsJ, - jlong vectorsAddressJ, jint dimJ, - jstring indexPathJ, jobject parametersJ) +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndex(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ) { try { std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); - knn_jni::faiss_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, parametersJ, &indexService); + return knn_jni::faiss_wrapper::InitIndex(&jniUtil, env, numDocs, dimJ, parametersJ, &indexService); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return (jlong)0; +} - // Releasing the vectorsAddressJ memory as that is not required once we have created the index. - // This is not the ideal approach, please refer this gh issue for long term solution: - // https://github.com/opensearch-project/k-NN/issues/1600 - delete reinterpret_cast<std::vector<float>*>(vectorsAddressJ); +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndex(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ) +{ + try { + std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods)); + return knn_jni::faiss_wrapper::InitIndex(&jniUtil, env, numDocs, dimJ, parametersJ, &binaryIndexService); } catch (...) 
{ jniUtil.CatchCppExceptionAndThrowJava(env); } + return (jlong)0; } -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ, +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JNIEnv * env, jclass cls, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, - jstring indexPathJ, jobject parametersJ) + jlong indexAddress, jint threadCount) { try { std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); - knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods)); - knn_jni::faiss_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, parametersJ, &binaryIndexService); + knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::InsertToIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexAddress, threadCount, &indexService); + } catch (...) { + // NOTE: ADDING DELETE STATEMENT HERE CAUSES A CRASH! + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} - // Releasing the vectorsAddressJ memory as that is not required once we have created the index. - // This is not the ideal approach, please refer this gh issue for long term solution: - // https://github.com/opensearch-project/k-NN/issues/1600 +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, + jlong indexAddress, jint threadCount) +{ + try { + std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::InsertToIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexAddress, threadCount, &binaryIndexService); delete reinterpret_cast<std::vector<uint8_t>*>(vectorsAddressJ); + } catch (...) 
{ + // NOTE: ADDING DELETE STATEMENT HERE CAUSES A CRASH! + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEnv * env, jclass cls, + jlong indexAddress, + jstring indexPathJ) +{ + try { + std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::WriteIndex(&jniUtil, env, indexPathJ, indexAddress, &indexService); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex(JNIEnv * env, jclass cls, + jlong indexAddress, + jstring indexPathJ) +{ + try { + std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::WriteIndex(&jniUtil, env, indexPathJ, indexAddress, &binaryIndexService); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } diff --git a/jni/src/org_opensearch_knn_jni_JNICommons.cpp b/jni/src/org_opensearch_knn_jni_JNICommons.cpp index 0bc2e4633..7432c44d3 100644 --- a/jni/src/org_opensearch_knn_jni_JNICommons.cpp +++ b/jni/src/org_opensearch_knn_jni_JNICommons.cpp @@ -38,11 +38,11 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) { JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData(JNIEnv * env, jclass cls, -jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ) +jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) { try { - return knn_jni::commons::storeVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ); + return knn_jni::commons::storeVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ, appendJ); } catch (...) 
{ jniUtil.CatchCppExceptionAndThrowJava(env); } @@ -50,11 +50,11 @@ jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ) } JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData(JNIEnv * env, jclass cls, -jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ) +jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) { try { - return knn_jni::commons::storeByteVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ); + return knn_jni::commons::storeByteVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ, appendJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } diff --git a/jni/tests/commons_test.cpp b/jni/tests/commons_test.cpp index 630358919..d469fe268 100644 --- a/jni/tests/commons_test.cpp +++ b/jni/tests/commons_test.cpp @@ -33,7 +33,7 @@ TEST(CommonsTests, BasicAssertions) { testing::NiceMock<test_util::MockJNIUtil> mockJNIUtil; jlong memoryAddress = knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, (jlong)0, - reinterpret_cast<jobjectArray>(&data), (jlong)(totalNumberOfVector * dim)); + reinterpret_cast<jobjectArray>(&data), (jlong)(totalNumberOfVector * dim), true); ASSERT_NE(memoryAddress, 0); auto *vect = reinterpret_cast<std::vector<float>*>(memoryAddress); ASSERT_EQ(vect->size(), data.size() * dim); @@ -48,12 +48,13 @@ TEST(CommonsTests, BasicAssertions) { } data2.push_back(vector); memoryAddress = knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, memoryAddress, - reinterpret_cast<jobjectArray>(&data2), (jlong)(totalNumberOfVector * dim)); + reinterpret_cast<jobjectArray>(&data2), (jlong)(totalNumberOfVector * dim), true); ASSERT_NE(memoryAddress, 0); ASSERT_EQ(memoryAddress, oldMemoryAddress); vect = reinterpret_cast<std::vector<float>*>(memoryAddress); int currentIndex = 0; - ASSERT_EQ(vect->size(), totalNumberOfVector*dim); + std::cout << vect->size() + "\n"; + ASSERT_EQ(vect->size(), totalNumberOfVector * dim); 
ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim); // Validate if all vectors data are at correct location @@ -70,6 +71,115 @@ TEST(CommonsTests, BasicAssertions) { currentIndex++; } } + + // test append == false + std::vector<std::vector<float>> data3; + std::vector<float> vecto3; + for(int j = 0 ; j < dim ; j ++) { + vecto3.push_back((float)j); + } + data3.push_back(vecto3); + memoryAddress = knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, memoryAddress, + reinterpret_cast<jobjectArray>(&data3), (jlong)(totalNumberOfVector * dim), false); + ASSERT_NE(memoryAddress, 0); + ASSERT_EQ(memoryAddress, oldMemoryAddress); + vect = reinterpret_cast<std::vector<float>*>(memoryAddress); + + ASSERT_EQ(vect->size(), dim); //Since we just added 1 vector + ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim); //This is the initial capacity allocated + + currentIndex = 0; + for(auto & i : data3) { + for(float j : i) { + ASSERT_FLOAT_EQ(vect->at(currentIndex), j); + currentIndex++; + } + } + + // Check that freeing vector data works + knn_jni::commons::freeVectorData(memoryAddress); +} + +TEST(StoreByteVectorTest, BasicAssertions) { + long dim = 3; + long totalNumberOfVector = 5; + std::vector<std::vector<uint8_t>> data; + for(int i = 0 ; i < totalNumberOfVector - 1 ; i++) { + std::vector<uint8_t> vector; + for(int j = 0 ; j < dim ; j ++) { + vector.push_back((uint8_t)j); + } + data.push_back(vector); + } + JNIEnv *jniEnv = nullptr; + + testing::NiceMock<test_util::MockJNIUtil> mockJNIUtil; + + jlong memoryAddress = knn_jni::commons::storeByteVectorData(&mockJNIUtil, jniEnv, (jlong)0, + reinterpret_cast<jobjectArray>(&data), (jlong)(totalNumberOfVector * dim), true); + ASSERT_NE(memoryAddress, 0); + auto *vect = reinterpret_cast<std::vector<uint8_t>*>(memoryAddress); + ASSERT_EQ(vect->size(), data.size() * dim); + ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim); + + // Check by inserting more vectors at same memory location + jlong oldMemoryAddress =
memoryAddress; + std::vector<std::vector<uint8_t>> data2; + std::vector<uint8_t> vector; + for(int j = 0 ; j < dim ; j ++) { + vector.push_back((uint8_t)j); + } + data2.push_back(vector); + memoryAddress = knn_jni::commons::storeByteVectorData(&mockJNIUtil, jniEnv, memoryAddress, + reinterpret_cast<jobjectArray>(&data2), (jlong)(totalNumberOfVector * dim), true); + ASSERT_NE(memoryAddress, 0); + ASSERT_EQ(memoryAddress, oldMemoryAddress); + vect = reinterpret_cast<std::vector<uint8_t>*>(memoryAddress); + int currentIndex = 0; + ASSERT_EQ(vect->size(), totalNumberOfVector*dim); + ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim); + + // Validate if all vectors data are at correct location + for(auto & i : data) { + for(uint8_t j : i) { + ASSERT_EQ(vect->at(currentIndex), j); + currentIndex++; + } + } + + for(auto & i : data2) { + for(uint8_t j : i) { + ASSERT_EQ(vect->at(currentIndex), j); + currentIndex++; + } + } + + // test append == false + std::vector<std::vector<uint8_t>> data3; + std::vector<uint8_t> vecto3; + for(int j = 0 ; j < dim ; j ++) { + vecto3.push_back((uint8_t)j); + } + data3.push_back(vecto3); + memoryAddress = knn_jni::commons::storeByteVectorData(&mockJNIUtil, jniEnv, memoryAddress, + reinterpret_cast<jobjectArray>(&data3), (jlong)(totalNumberOfVector * dim), false); + ASSERT_NE(memoryAddress, 0); + ASSERT_EQ(memoryAddress, oldMemoryAddress); + vect = reinterpret_cast<std::vector<uint8_t>*>(memoryAddress); + + ASSERT_EQ(vect->size(), dim); + ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim); + + currentIndex = 0; + for(auto & i : data3) { + for(uint8_t j : i) { + ASSERT_EQ(vect->at(currentIndex), j); + currentIndex++; + } + } + + // Check that freeing vector data works + knn_jni::commons::freeVectorData(memoryAddress); } TEST(CommonTests, GetIntegerMethodParam) { diff --git a/jni/tests/faiss_index_service_test.cpp b/jni/tests/faiss_index_service_test.cpp index f876edced..1f00f6a1d 100644 --- a/jni/tests/faiss_index_service_test.cpp +++
b/jni/tests/faiss_index_service_test.cpp @@ -64,18 +64,9 @@ TEST(CreateIndexTest, BasicAssertions) { // Create the index knn_jni::faiss_wrapper::IndexService indexService(std::move(mockFaissMethods)); - indexService.createIndex( - &mockJNIUtil, - jniEnv, - metricType, - indexDescription, - dim, - numIds, - threadCount, - (int64_t) &vectors, - ids, - indexPath, - parametersMap); + long indexAddress = indexService.initIndex(&mockJNIUtil, jniEnv, metricType, indexDescription, dim, numIds, threadCount, parametersMap); + indexService.insertToIndex(dim, numIds, threadCount, (int64_t) &vectors, ids, indexAddress); + indexService.writeIndex(indexPath, indexAddress); } TEST(CreateBinaryIndexTest, BasicAssertions) { @@ -119,16 +110,7 @@ TEST(CreateBinaryIndexTest, BasicAssertions) { // Create the index knn_jni::faiss_wrapper::BinaryIndexService indexService(std::move(mockFaissMethods)); - indexService.createIndex( - &mockJNIUtil, - jniEnv, - metricType, - indexDescription, - dim, - numIds, - threadCount, - (int64_t) &vectors, - ids, - indexPath, - parametersMap); + long indexAddress = indexService.initIndex(&mockJNIUtil, jniEnv, metricType, indexDescription, dim, numIds, threadCount, parametersMap); + indexService.insertToIndex(dim, numIds, threadCount, (int64_t) &vectors, ids, indexAddress); + indexService.writeIndex(indexPath, indexAddress); } \ No newline at end of file diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 5ae443837..a1839c6ce 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -32,6 +32,70 @@ float rangeSearchRandomDataMin = -50; float rangeSearchRandomDataMax = 50; float rangeSearchRadius = 20000; +void createIndexIteratively( + knn_jni::JNIUtilInterface * JNIUtil, + JNIEnv *jniEnv, + std::vector<faiss::idx_t> & ids, + std::vector<float> & vectors, + int dim, + std::string & indexPath, + std::unordered_map<string, jobject> parametersMap, + IndexService * indexService, + int insertions 
= 10 + ) { + long numDocs = ids.size(); + if(numDocs % insertions != 0) { + throw std::invalid_argument("Number of documents should be divisible by number of insertions"); + } + long docsPerInsertion = numDocs / insertions; + long index_ptr = knn_jni::faiss_wrapper::InitIndex(JNIUtil, jniEnv, numDocs, dim, (jobject)&parametersMap, indexService); + for(int i = 0; i < insertions; i++) { + int start_idx = i * docsPerInsertion; + int end_idx = start_idx + docsPerInsertion; + std::vector<faiss::idx_t> insertIds; + std::vector<float> insertVecs; + for(int j = start_idx; j < end_idx; j++) { + insertIds.push_back(j); + for(int k = 0; k < dim; k++) { + insertVecs.push_back(vectors[j * dim + k]); + } + } + knn_jni::faiss_wrapper::InsertToIndex(JNIUtil, jniEnv, reinterpret_cast<jintArray>(&insertIds), (jlong)&insertVecs, dim, index_ptr, 0, indexService); + } + knn_jni::faiss_wrapper::WriteIndex(JNIUtil, jniEnv, (jstring)&indexPath, index_ptr, indexService); +} + +void createBinaryIndexIteratively( + knn_jni::JNIUtilInterface * JNIUtil, + JNIEnv *jniEnv, + std::vector<faiss::idx_t> & ids, + std::vector<uint8_t> & vectors, + int dim, + std::string & indexPath, + std::unordered_map<string, jobject> parametersMap, + IndexService * indexService, + int insertions = 10 + ) { + long numDocs = ids.size();; + long index_ptr = knn_jni::faiss_wrapper::InitIndex(JNIUtil, jniEnv, numDocs, dim, (jobject)&parametersMap, indexService); + for(int i = 0; i < insertions; i++) { + int start_idx = numDocs * i / insertions; + int end_idx = numDocs * (i + 1) / insertions; + int docs_to_insert = end_idx - start_idx; + if(docs_to_insert == 0) continue; + std::vector<faiss::idx_t> insertIds; + std::vector<float> insertVecs; + for(int j = start_idx; j < end_idx; j++) { + insertIds.push_back(j); + for(int k = 0; k < dim / 8; k++) { + insertVecs.push_back(vectors[j * (dim / 8) + k]); + } + } + knn_jni::faiss_wrapper::InsertToIndex(JNIUtil, jniEnv, reinterpret_cast<jintArray>(&insertIds), (jlong)&insertVecs, dim,
index_ptr, 0, indexService); + } + knn_jni::faiss_wrapper::WriteIndex(JNIUtil, jniEnv, (jstring)&indexPath, index_ptr, indexService); +} + TEST(FaissCreateIndexTest, BasicAssertions) { // Define the data faiss::idx_t numIds = 200; @@ -63,13 +127,15 @@ TEST(FaissCreateIndexTest, BasicAssertions) { // Create the index std::unique_ptr<FaissMethods> faissMethods(new FaissMethods()); NiceMock<MockIndexService> mockIndexService(std::move(faissMethods)); - EXPECT_CALL(mockIndexService, createIndex(_, _, faiss::METRIC_L2, indexDescription, dim, (int)numIds, 0, (int64_t)&vectors, ids, indexPath, subParametersMap)) + int insertions = 10; + EXPECT_CALL(mockIndexService, initIndex(_, _, faiss::METRIC_L2, indexDescription, dim, (int)numIds, 0, subParametersMap)) + .Times(1); + EXPECT_CALL(mockIndexService, insertToIndex(dim, numIds / insertions, 0, _, _, _)) + .Times(insertions); + EXPECT_CALL(mockIndexService, writeIndex(indexPath, _)) .Times(1); - knn_jni::faiss_wrapper::CreateIndex( - &mockJNIUtil, jniEnv, reinterpret_cast<jintArray>(&ids), - (jlong) &vectors, dim , (jstring)&indexPath, - (jobject)&parametersMap, &mockIndexService); + createIndexIteratively(&mockJNIUtil, jniEnv, ids, vectors, dim, indexPath, parametersMap, &mockIndexService, insertions); } TEST(FaissCreateBinaryIndexTest, BasicAssertions) { @@ -103,14 +169,16 @@ TEST(FaissCreateBinaryIndexTest, BasicAssertions) { // Create the index std::unique_ptr<FaissMethods> faissMethods(new FaissMethods()); NiceMock<MockIndexService> mockIndexService(std::move(faissMethods)); - EXPECT_CALL(mockIndexService, createIndex(_, _, faiss::METRIC_L2, indexDescription, dim, (int)numIds, 0, (int64_t)&vectors, ids, indexPath, subParametersMap)) + int insertions = 10; + EXPECT_CALL(mockIndexService, initIndex(_, _, faiss::METRIC_L2, indexDescription, dim, (int)numIds, 0, subParametersMap)) + .Times(1); + EXPECT_CALL(mockIndexService, insertToIndex(dim, numIds / insertions, 0, _, _, _)) + .Times(insertions); +
EXPECT_CALL(mockIndexService, writeIndex(indexPath, _)) .Times(1); // This method calls delete vectors at the end - knn_jni::faiss_wrapper::CreateIndex( - &mockJNIUtil, jniEnv, reinterpret_cast<jintArray>(&ids), - (jlong) &vectors, dim , (jstring)&indexPath, - (jobject)&parametersMap, &mockIndexService); + createBinaryIndexIteratively(&mockJNIUtil, jniEnv, ids, vectors, dim, indexPath, parametersMap, &mockIndexService, insertions); } TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) { @@ -683,10 +751,8 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) { // Create the index std::unique_ptr<FaissMethods> faissMethods(new FaissMethods()); knn_jni::faiss_wrapper::IndexService IndexService(std::move(faissMethods)); - knn_jni::faiss_wrapper::CreateIndex( - &mockJNIUtil, jniEnv, reinterpret_cast<jintArray>(&ids), - (jlong)&vectors, dim, (jstring)&indexPath, - (jobject)&parametersMap, &IndexService); + + createIndexIteratively(&mockJNIUtil, jniEnv, ids, vectors, dim, indexPath, parametersMap, &IndexService); // Make sure index can be loaded std::unique_ptr<faiss::Index> index(test_util::FaissLoadIndex(indexPath)); diff --git a/jni/tests/mocks/faiss_index_service_mock.h b/jni/tests/mocks/faiss_index_service_mock.h index 7af08c82e..285e34053 100644 --- a/jni/tests/mocks/faiss_index_service_mock.h +++ b/jni/tests/mocks/faiss_index_service_mock.h @@ -23,20 +23,37 @@ class MockIndexService : public IndexService { public: MockIndexService(std::unique_ptr<FaissMethods> faissMethods) : IndexService(std::move(faissMethods)) {}; MOCK_METHOD( - void, - createIndex, + long, + initIndex, ( knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, faiss::MetricType metric, std::string indexDescription, + int dim, + int numIds, + int threadCount, + StringToJObjectMap parameters + ), + (override)); + MOCK_METHOD( + void, + insertToIndex, + ( int dim, int numIds, int threadCount, int64_t vectorsAddress, - std::vector<int64_t> ids, + std::vector<int64_t> & ids, + long indexPtr + ), +
(override)); + MOCK_METHOD( + void, + writeIndex, + ( std::string indexPath, - StringToJObjectMap parameters + long indexPtr ), (override)); }; diff --git a/src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java b/src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java index 591f16735..bb7cce485 100644 --- a/src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java +++ b/src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java @@ -9,15 +9,33 @@ import org.apache.commons.lang.StringUtils; import org.apache.lucene.index.FieldInfo; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.indices.ModelUtil.getModelMetadata; + /** * A utility class to extract information from FieldInfo. */ @UtilityClass public class FieldInfoExtractor { + /** + * Extracts KNNEngine from FieldInfo + * @param field {@link FieldInfo} + * @return {@link KNNEngine} + */ + public static KNNEngine extractKNNEngine(final FieldInfo field) { + final ModelMetadata modelMetadata = getModelMetadata(field.attributes().get(MODEL_ID)); + if (modelMetadata != null) { + return modelMetadata.getKnnEngine(); + } + final String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName()); + return KNNEngine.getEngine(engineName); + } + /** * Extract vector data type from fieldInfo * @param fieldInfo {@link FieldInfo} diff --git a/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java b/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java index fd9e5b6c2..9381f73e8 100644 --- a/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java +++ b/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java @@ -5,10 +5,12 @@ package org.opensearch.knn.common; -import java.util.Objects; import 
lombok.AccessLevel; import lombok.NoArgsConstructor; +import java.util.ArrayList; +import java.util.Objects; + @NoArgsConstructor(access = AccessLevel.PRIVATE) public class KNNVectorUtil { /** @@ -42,4 +44,18 @@ public static boolean isZeroVector(float[] vector) { } return true; } + + /** + * Creates an int overflow safe arraylist. If there is an overflow it will create a list with default initial size + * @param batchSize size to allocate + * @return an arrayList + */ + public static <T> ArrayList<T> createArrayList(long batchSize) { + try { + return new ArrayList<>(Math.toIntExact(batchSize)); + } catch (Exception exception) { + // No-op + } + return new ArrayList<>(); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index 69229036e..8beced605 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -8,8 +8,11 @@ import lombok.AllArgsConstructor; import lombok.extern.log4j.Log4j2; import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsFormat; import org.opensearch.knn.index.codec.params.KNNScalarQuantizedVectorsFormatParams; import org.opensearch.knn.index.codec.params.KNNVectorsFormatParams; import org.opensearch.knn.index.engine.KNNEngine; @@ -17,6 +20,7 @@ import org.opensearch.knn.index.mapper.KNNMappingConfig; import org.opensearch.knn.index.mapper.KNNVectorFieldType; +import java.util.Map; import java.util.Optional; import java.util.function.Function; import java.util.function.Supplier; @@ -78,42 
+82,47 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { KNNMethodContext knnMethodContext = knnMappingConfig.getKnnMethodContext() .orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); - var params = knnMethodContext.getMethodComponentContext().getParameters(); + final KNNEngine engine = knnMethodContext.getKnnEngine(); + final Map<String, Object> params = knnMethodContext.getMethodComponentContext().getParameters(); - if (knnMethodContext.getKnnEngine() == KNNEngine.LUCENE && params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) { - KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams( - params, - defaultMaxConnections, - defaultBeamWidth - ); - if (knnScalarQuantizedVectorsFormatParams.validate(params)) { - log.debug( - "Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\"", - field, - MAX_CONNECTIONS, - knnScalarQuantizedVectorsFormatParams.getMaxConnections(), - BEAM_WIDTH, - knnScalarQuantizedVectorsFormatParams.getBeamWidth(), - LUCENE_SQ_CONFIDENCE_INTERVAL, - knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(), - LUCENE_SQ_BITS, - knnScalarQuantizedVectorsFormatParams.getBits() + if (engine == KNNEngine.LUCENE) { + if (params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) { + KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams( + params, + defaultMaxConnections, + defaultBeamWidth ); - return scalarQuantizedVectorsFormatSupplier.apply(knnScalarQuantizedVectorsFormatParams); + if (knnScalarQuantizedVectorsFormatParams.validate(params)) { + log.debug( + "Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\"", + field, + MAX_CONNECTIONS, + knnScalarQuantizedVectorsFormatParams.getMaxConnections(), + BEAM_WIDTH, + 
knnScalarQuantizedVectorsFormatParams.getBeamWidth(), + LUCENE_SQ_CONFIDENCE_INTERVAL, + knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(), + LUCENE_SQ_BITS, + knnScalarQuantizedVectorsFormatParams.getBits() + ); + return scalarQuantizedVectorsFormatSupplier.apply(knnScalarQuantizedVectorsFormatParams); + } } + KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth); + log.debug( + "Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"", + field, + MAX_CONNECTIONS, + knnVectorsFormatParams.getMaxConnections(), + BEAM_WIDTH, + knnVectorsFormatParams.getBeamWidth() + ); + return vectorsFormatSupplier.apply(knnVectorsFormatParams); } - KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth); - log.debug( - "Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"", - field, - MAX_CONNECTIONS, - knnVectorsFormatParams.getMaxConnections(), - BEAM_WIDTH, - knnVectorsFormatParams.getBeamWidth() - ); - return vectorsFormatSupplier.apply(knnVectorsFormatParams); + // All native engines to use NativeEngines990KnnVectorsFormat + return new NativeEngines990KnnVectorsFormat(new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer())); } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 55ac5c597..218c9d891 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -5,76 +5,40 @@ package org.opensearch.knn.index.codec.KNN80Codec; -import lombok.NonNull; import lombok.extern.log4j.Log4j2; -import org.apache.lucene.store.ChecksumIndexInput; import 
org.opensearch.common.StopWatch; -import org.opensearch.common.xcontent.XContentHelper; -import org.opensearch.core.common.bytes.BytesArray; -import org.opensearch.core.xcontent.MediaTypeRegistry; -import org.opensearch.core.xcontent.DeprecationHandler; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.util.IndexUtil; -import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.codec.transfer.VectorTransfer; -import org.opensearch.knn.index.codec.transfer.VectorTransferByte; -import org.opensearch.knn.index.codec.transfer.VectorTransferFloat; -import org.opensearch.knn.jni.JNIService; -import org.opensearch.knn.index.codec.util.KNNCodecUtil; import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.indices.Model; -import org.opensearch.knn.indices.ModelCache; -import org.opensearch.knn.plugin.stats.KNNCounter; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.codecs.DocValuesConsumer; import org.apache.lucene.codecs.DocValuesProducer; -import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.store.FSDirectory; -import org.apache.lucene.store.FilterDirectory; +import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; -import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.plugin.stats.KNNGraphValue; -import java.io.Closeable; import java.io.IOException; -import java.io.OutputStream; -import java.nio.ByteBuffer; -import 
java.nio.ByteOrder; -import java.nio.file.Files; -import java.nio.file.Paths; -import java.nio.file.StandardOpenOption; -import java.security.AccessController; -import java.security.PrivilegedAction; -import java.util.HashMap; -import java.util.Map; -import static org.apache.lucene.codecs.CodecUtil.FOOTER_MAGIC; -import static org.opensearch.knn.common.KNNConstants.MODEL_ID; -import static org.opensearch.knn.common.KNNConstants.PARAMETERS; -import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName; -import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize; -import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; +import static org.opensearch.knn.common.FieldInfoExtractor.extractKNNEngine; +import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; /** * This class writes the KNN docvalues to the segments */ @Log4j2 -class KNN80DocValuesConsumer extends DocValuesConsumer implements Closeable { +class KNN80DocValuesConsumer extends DocValuesConsumer { private final Logger logger = LogManager.getLogger(KNN80DocValuesConsumer.class); private final DocValuesConsumer delegatee; private final SegmentWriteState state; - private static final Long CRC32_CHECKSUM_SANITY = 0xFFFFFFFF00000000L; - KNN80DocValuesConsumer(DocValuesConsumer delegatee, SegmentWriteState state) { this.delegatee = delegatee; this.state = state; @@ -86,7 +50,7 @@ public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) th if (isKNNBinaryFieldRequired(field)) { StopWatch stopWatch = new StopWatch(); stopWatch.start(); - addKNNBinaryField(field, valuesProducer, false, true); + addKNNBinaryField(field, valuesProducer, false); stopWatch.stop(); long time_in_millis = stopWatch.totalTime().millis(); KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.set(KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() + time_in_millis); @@ -95,174 +59,21 @@ public void addBinaryField(FieldInfo 
field, DocValuesProducer valuesProducer) th } private boolean isKNNBinaryFieldRequired(FieldInfo field) { - final KNNEngine knnEngine = getKNNEngine(field); + final KNNEngine knnEngine = extractKNNEngine(field); log.debug(String.format("Read engine [%s] for field [%s]", knnEngine.getName(), field.getName())); return field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD) && KNNEngine.getEnginesThatCreateCustomSegmentFiles().stream().anyMatch(engine -> engine == knnEngine); } - private KNNEngine getKNNEngine(@NonNull FieldInfo field) { - final String modelId = field.attributes().get(MODEL_ID); - if (modelId != null) { - var model = ModelCache.getInstance().get(modelId); - return model.getModelMetadata().getKnnEngine(); - } - final String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName()); - return KNNEngine.getEngine(engineName); - } - - public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh) - throws IOException { - // Get values to be indexed - BinaryDocValues values = valuesProducer.getBinary(field); - final KNNEngine knnEngine = getKNNEngine(field); - final String engineFileName = buildEngineFileName( - state.segmentInfo.name, - knnEngine.getVersion(), - field.name, - knnEngine.getExtension() - ); - final String indexPath = Paths.get( - ((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), - engineFileName - ).toString(); - - // Determine if we are creating an index from a model or from scratch - NativeIndexCreator indexCreator; - KNNCodecUtil.Pair pair; - Map<String, String> fieldAttributes = field.attributes(); - VectorDataType vectorDataType; - - if (fieldAttributes.containsKey(MODEL_ID)) { - String modelId = fieldAttributes.get(MODEL_ID); - Model model = ModelCache.getInstance().get(modelId); - if (model.getModelBlob() == null) { - throw new RuntimeException(String.format("There is no trained model with id 
\"%s\"", modelId)); - } - vectorDataType = model.getModelMetadata().getVectorDataType(); - pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType)); - indexCreator = () -> createKNNIndexFromTemplate(model, pair, knnEngine, indexPath); - } else { - // get vector data type from field attributes or provide default value - vectorDataType = VectorDataType.get( - fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()) - ); - pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType)); - indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath); - } - - // Skip index creation if no vectors or docs in segment - if (pair.getVectorAddress() == 0 || pair.docs.length == 0) { - logger.info("Skipping engine index creation as there are no vectors or docs in the segment"); - return; - } - - long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), vectorDataType); + public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge) throws IOException { + final VectorDataType vectorDataType = extractVectorDataType(field); + final KNNVectorValues<?> knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, valuesProducer.getBinary(field)); if (isMerge) { - KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); - KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(pair.docs.length); - KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize); - recordMergeStats(pair.docs.length, arraySize); - } - - // Increment counter for number of graph index requests - KNNCounter.GRAPH_INDEX_REQUESTS.increment(); - - if (isRefresh) { - recordRefreshStats(); - } - - // Ensure engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper - state.directory.createOutput(engineFileName, state.context).close(); - indexCreator.createIndex(); - writeFooter(indexPath, engineFileName); - } - - private void recordMergeStats(int length, 
long arraySize) { - KNNGraphValue.MERGE_CURRENT_OPERATIONS.decrement(); - KNNGraphValue.MERGE_CURRENT_DOCS.decrementBy(length); - KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.decrementBy(arraySize); - KNNGraphValue.MERGE_TOTAL_OPERATIONS.increment(); - KNNGraphValue.MERGE_TOTAL_DOCS.incrementBy(length); - KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.incrementBy(arraySize); - } - - private void recordRefreshStats() { - KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment(); - } - - private void createKNNIndexFromTemplate(Model model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) { - Map<String, Object> parameters = new HashMap<>(); - parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); - - IndexUtil.updateVectorDataTypeToParameters(parameters, model.getModelMetadata().getVectorDataType()); - - AccessController.doPrivileged((PrivilegedAction<Void>) () -> { - JNIService.createIndexFromTemplate( - pair.docs, - pair.getVectorAddress(), - pair.getDimension(), - indexPath, - model.getModelBlob(), - parameters, - knnEngine - ); - return null; - }); - } - - private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) - throws IOException { - Map<String, Object> parameters = new HashMap<>(); - Map<String, String> fieldAttributes = fieldInfo.attributes(); - String parametersString = fieldAttributes.get(PARAMETERS); - // parametersString will be null when legacy mapper is used - if (parametersString == null) { - parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue())); - - String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION); - Map<String, Object> algoParams = new HashMap<>(); - if (efConstruction != null) { - algoParams.put(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, Integer.parseInt(efConstruction)); - } - - String m = 
fieldAttributes.get(KNNConstants.HNSW_ALGO_M); - if (m != null) { - algoParams.put(KNNConstants.METHOD_PARAMETER_M, Integer.parseInt(m)); - } - parameters.put(PARAMETERS, algoParams); + NativeIndexWriter.getWriter(field, state).mergeIndex(knnVectorValues); } else { - parameters.putAll( - XContentHelper.createParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, - new BytesArray(parametersString), - MediaTypeRegistry.getDefaultMediaType() - ).map() - ); - } - - // Update index description of Faiss for binary data type - if (KNNEngine.FAISS == knnEngine - && VectorDataType.BINARY.getValue() - .equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue())) - && parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null) { - parameters.put( - KNNConstants.INDEX_DESCRIPTION_PARAMETER, - FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() - ); - IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY); + NativeIndexWriter.getWriter(field, state).flushIndex(knnVectorValues); } - - // Used to determine how many threads to use when indexing - parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); - - // Pass the path for the nms library to save the file - AccessController.doPrivileged((PrivilegedAction<Void>) () -> { - JNIService.createIndex(pair.docs, pair.getVectorAddress(), pair.getDimension(), indexPath, parameters, knnEngine); - return null; - }); } /** @@ -281,7 +92,7 @@ public void merge(MergeState mergeState) { if (type == DocValuesType.BINARY && fieldInfo.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) { StopWatch stopWatch = new StopWatch(); stopWatch.start(); - addKNNBinaryField(fieldInfo, new KNN80DocValuesReader(mergeState), true, false); + addKNNBinaryField(fieldInfo, new KNN80DocValuesReader(mergeState), true); 
stopWatch.stop(); long time_in_millis = stopWatch.totalTime().millis(); KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.set(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() + time_in_millis); @@ -317,52 +128,4 @@ public void addNumericField(FieldInfo field, DocValuesProducer valuesProducer) t public void close() throws IOException { delegatee.close(); } - - @FunctionalInterface - private interface NativeIndexCreator { - void createIndex() throws IOException; - } - - private void writeFooter(String indexPath, String engineFileName) throws IOException { - // Opens the engine file that was created and appends a footer to it. The footer consists of - // 1. A Footer magic number (int - 4 bytes) - // 2. A checksum algorithm id (int - 4 bytes) - // 3. A checksum (long - bytes) - // The checksum is computed on all the bytes written to the file up to that point. - // Logic where footer is written in Lucene can be found here: - // https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L390-L412 - OutputStream os = Files.newOutputStream(Paths.get(indexPath), StandardOpenOption.APPEND); - ByteBuffer byteBuffer = ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN); - byteBuffer.putInt(FOOTER_MAGIC); - byteBuffer.putInt(0); - os.write(byteBuffer.array()); - os.flush(); - - ChecksumIndexInput checksumIndexInput = state.directory.openChecksumInput(engineFileName, state.context); - checksumIndexInput.seek(checksumIndexInput.length()); - long value = checksumIndexInput.getChecksum(); - checksumIndexInput.close(); - - if (isChecksumValid(value)) { - throw new IllegalStateException("Illegal CRC-32 checksum: " + value + " (resource=" + os + ")"); - } - - // Write the CRC checksum to the end of the OutputStream and close the stream - byteBuffer.putLong(0, value); - os.write(byteBuffer.array()); - os.close(); - } - - private boolean isChecksumValid(long value) { - // Check pulled from - // 
https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L644-L647 - return (value & CRC32_CHECKSUM_SANITY) != 0; - } - - private VectorTransfer getVectorTransfer(VectorDataType vectorDataType) { - if (VectorDataType.BINARY == vectorDataType) { - return new VectorTransferByte(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); - } - return new VectorTransferFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); - } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java index e4860af31..1abb84944 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java @@ -30,6 +30,7 @@ */ class NativeEngineFieldVectorsWriter<T> extends KnnFieldVectorsWriter<T> { private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NativeEngineFieldVectorsWriter.class); + @Getter private final FieldInfo fieldInfo; /** * We are using a map here instead of list, because for sampler interface for quantization we have to advance the iterator @@ -77,6 +78,8 @@ public void addValue(int docID, T vectorValue) { + "\" appears more than once in this document (only one value is allowed per field)" ); } + // TODO: we can build the graph here too iteratively. but right now I am skipping that as we need iterative + // graph build support on the JNI layer. 
assert docID > lastDocID; vectors.put(docID, vectorValue); docsWithField.add(docID); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index b81ec9789..8ae09222d 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -12,23 +12,33 @@ package org.opensearch.knn.index.codec.KNN990Codec; import lombok.RequiredArgsConstructor; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; + /** * A KNNVectorsWriter class for writing the vector data strcutures and flat vectors for Native Engines. 
*/ +@Log4j2 @RequiredArgsConstructor public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NativeEngines990KnnVectorsWriter.class); @@ -46,8 +56,6 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { @Override public KnnFieldVectorsWriter<?> addField(final FieldInfo fieldInfo) throws IOException { final NativeEngineFieldVectorsWriter<?> newField = NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream); - // TODO: we can build the graph here too iteratively. but right now I am skipping that as we need iterative - // graph build support on the JNI layer. fields.add(newField); return flatVectorsWriter.addField(fieldInfo, newField); } @@ -62,14 +70,42 @@ public KnnFieldVectorsWriter<?> addField(final FieldInfo fieldInfo) throws IOExc public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { // simply write data in the flat file flatVectorsWriter.flush(maxDoc, sortMap); - // TODO: add code for creating Vector datastructures during lucene flush operation + for (final NativeEngineFieldVectorsWriter<?> field : fields) { + final VectorDataType vectorDataType = extractVectorDataType(field.getFieldInfo()); + final KNNVectorValues<?> knnVectorValues = KNNVectorValuesFactory.getVectorValues( + vectorDataType, + field.getDocsWithField(), + field.getVectors() + ); + + // TODO: Extract quantization state here + NativeIndexWriter.getWriter(field.getFieldInfo(), segmentWriteState).flushIndex(knnVectorValues); + } } @Override public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState) throws IOException { // This will ensure that we are merging the FlatIndex during force merge. flatVectorsWriter.mergeOneField(fieldInfo, mergeState); - // TODO: add code for creating Vector datastructures during merge operation + + // For merge, pick values from flat vector and reindex again. 
This will use the flush operation to create graphs + final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); + final KNNVectorValues<?> knnVectorValues; + switch (fieldInfo.getVectorEncoding()) { + case FLOAT32: + final FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedFloats); + break; + case BYTE: + final ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); + knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedBytes); + break; + default: + throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]"); + } + + // TODO: Extract Quantization state here + NativeIndexWriter.getWriter(fieldInfo, segmentWriteState).mergeIndex(knnVectorValues); } /** diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java new file mode 100644 index 000000000..6df157c8d --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.codec.transfer.OffHeapByteQuantizedVectorTransfer; +import org.opensearch.knn.index.codec.transfer.OffHeapFloatVectorTransfer; +import org.opensearch.knn.index.codec.transfer.VectorTransfer; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import 
org.opensearch.knn.jni.JNIService; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Map; + +/** + * Iteratively builds the index. + */ +final class MemOptimizedNativeIndexBuildStrategy implements NativeIndexBuildStrategy { + + private static MemOptimizedNativeIndexBuildStrategy INSTANCE = new MemOptimizedNativeIndexBuildStrategy(); + + public static MemOptimizedNativeIndexBuildStrategy getInstance() { + return INSTANCE; + } + + private MemOptimizedNativeIndexBuildStrategy() {} + + public void buildAndWriteIndex(BuildIndexParams indexInfo, final KNNVectorValues<?> knnVectorValues) throws IOException { + // Needed to make sure we dont get 0 dimensions while initializing index + knnVectorValues.init(); + KNNEngine engine = indexInfo.getKnnEngine(); + Map<String, Object> indexParameters = indexInfo.getParameters(); + + // Initialize the index + long indexMemoryAddress = AccessController.doPrivileged( + (PrivilegedAction<Long>) () -> JNIService.initIndexFromScratch( + knnVectorValues.totalLiveDocs(), + knnVectorValues.dimension(), + indexParameters, + engine + ) + ); + + try (final VectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), knnVectorValues)) { + + while (vectorTransfer.hasNext()) { + vectorTransfer.transferBatch(); + long vectorAddress = vectorTransfer.getVectorAddress(); + int[] docs = vectorTransfer.getTransferredDocsIds(); + + // Insert vectors + AccessController.doPrivileged((PrivilegedAction<Void>) () -> { + JNIService.insertToIndex(docs, vectorAddress, knnVectorValues.dimension(), indexParameters, indexMemoryAddress, engine); + return null; + }); + } + + // Write vector + AccessController.doPrivileged((PrivilegedAction<Void>) () -> { + JNIService.writeIndex(indexInfo.getIndexPath(), indexMemoryAddress, engine, indexParameters); + return null; + }); + + } catch (Exception exception) { + throw new RuntimeException("Failed to build index", exception); 
+ } + } + + // TODO: Will probably need a factory once quantization is added + private VectorTransfer getVectorTransfer(VectorDataType vectorDataType, KNNVectorValues<?> knnVectorValues) throws IOException { + switch (vectorDataType) { + case FLOAT: + return new OffHeapFloatVectorTransfer((KNNFloatVectorValues) knnVectorValues); + case BINARY: + case BYTE: + return new OffHeapByteQuantizedVectorTransfer<>((KNNVectorValues<byte[]>) knnVectorValues); + default: + throw new IllegalArgumentException("Unsupported vector data type: " + vectorDataType); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategy.java new file mode 100644 index 000000000..19475adfa --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategy.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; + +import java.io.IOException; + +/** + * Interface which dictates how the index needs to be built + */ +public interface NativeIndexBuildStrategy { + + void buildAndWriteIndex(BuildIndexParams indexInfo, final KNNVectorValues<?> knnVectorValues) throws IOException; +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java new file mode 100644 index 000000000..92d29b9ba --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java @@ -0,0 +1,280 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import lombok.AllArgsConstructor; 
+import lombok.extern.log4j.Log4j2; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.store.FilterDirectory; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.util.IndexUtil; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.indices.Model; +import org.opensearch.knn.indices.ModelCache; +import org.opensearch.knn.plugin.stats.KNNGraphValue; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.lucene.codecs.CodecUtil.FOOTER_MAGIC; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.opensearch.knn.common.FieldInfoExtractor.extractKNNEngine; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName; +import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; + +/** + * Writes KNN Index for a field in a segment. 
This is intended to be used for native engines + */ +@AllArgsConstructor +@Log4j2 +public class NativeIndexWriter { + private static final long CRC32_CHECKSUM_SANITY = 0xFFFFFFFF00000000L; + + private final SegmentWriteState state; + private final FieldInfo fieldInfo; + private final NativeIndexBuildStrategy indexBuilder; + // TODO: Add quantization state as a member variable + + /** + * Gets the writer for fieldInfo: iterative build for FAISS fields without a model id, full vector transfer otherwise + * + * @param fieldInfo field whose attributes determine the engine and whether a model id is present + * @return correct NativeIndexWriter to make index specified in fieldInfo + */ + public static NativeIndexWriter getWriter(final FieldInfo fieldInfo, SegmentWriteState state) { + // TODO: Fetch the quantization state here and pass it to NativeIndexWriter + + final KNNEngine knnEngine = extractKNNEngine(fieldInfo); + boolean isTemplate = fieldInfo.attributes().containsKey(MODEL_ID); + boolean iterative = !isTemplate && KNNEngine.FAISS == knnEngine; + if (iterative) { + return new NativeIndexWriter(state, fieldInfo, MemOptimizedNativeIndexBuildStrategy.getInstance()); + } + return new NativeIndexWriter(state, fieldInfo, VectorTransferIndexBuildStrategy.getInstance()); + } + + /** + * Initializes the vector values, builds and writes the index, then records the refresh stat + * + * @param knnVectorValues vectors to be written for this field + * @throws IOException if building or writing the index fails + */ + public void flushIndex(final KNNVectorValues<?> knnVectorValues) throws IOException { + knnVectorValues.init(); + buildAndWriteIndex(knnVectorValues); + recordRefreshStats(); + } + + /** + * Merges kNN index and records merge stats (live doc count and bytes per vector) around the build + * @param knnVectorValues vectors of the merged segment; nothing is built or recorded when it has no docs + * @throws IOException if building or writing the index fails + */ + public void mergeIndex(final KNNVectorValues<?> knnVectorValues) throws IOException { + knnVectorValues.init(); + if (knnVectorValues.docId() == NO_MORE_DOCS) { + // This is in place so we do not add metrics + return; + } + + long arraySize = knnVectorValues.bytesPerVector(); + startMergeStats((int) knnVectorValues.totalLiveDocs(), arraySize); // numDocs is the live doc count, not the vector dimension + buildAndWriteIndex(knnVectorValues); + endMergeStats((int) knnVectorValues.totalLiveDocs(), arraySize); + } + + private void buildAndWriteIndex(final KNNVectorValues<?>
knnVectorValues) throws IOException { + if (knnVectorValues.totalLiveDocs() == 0) { + log.warn("No live docs for field " + fieldInfo.name); + return; + } + + final KNNEngine knnEngine = extractKNNEngine(fieldInfo); + final String engineFileName = buildEngineFileName( + state.segmentInfo.name, + knnEngine.getVersion(), + fieldInfo.name, + knnEngine.getExtension() + ); + final String indexPath = Paths.get( + ((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), + engineFileName + ).toString(); + state.directory.createOutput(engineFileName, state.context).close(); // creates the empty engine file; presumably so the directory tracks it - TODO confirm + + final BuildIndexParams nativeIndexParams = indexParams(fieldInfo, indexPath, knnEngine); + indexBuilder.buildAndWriteIndex(nativeIndexParams, knnVectorValues); + writeFooter(indexPath, engineFileName, state); + } + + // The logic for building parameters needs to be cleaned up. There are various cases handled here + // Currently it falls under two categories - with model and without model. Without model is further divided based on vector data type + // TODO: Refactor this so it's scalable.
Possibly move it out of this class + private BuildIndexParams indexParams(FieldInfo fieldInfo, String indexPath, KNNEngine knnEngine) throws IOException { + final Map<String, Object> parameters; + final VectorDataType vectorDataType; + if (fieldInfo.attributes().containsKey(MODEL_ID)) { + Model model = getModel(fieldInfo); + vectorDataType = model.getModelMetadata().getVectorDataType(); + parameters = getTemplateParameters(fieldInfo, model); + } else { + vectorDataType = VectorDataType.get( + fieldInfo.attributes().getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()) + ); + parameters = getParameters(fieldInfo, vectorDataType, knnEngine); + } + + return BuildIndexParams.builder() + .parameters(parameters) + .vectorDataType(vectorDataType) + .knnEngine(knnEngine) + .indexPath(indexPath) + .build(); + } + + private Map<String, Object> getParameters(FieldInfo fieldInfo, VectorDataType vectorDataType, KNNEngine knnEngine) throws IOException { + Map<String, Object> parameters = new HashMap<>(); + Map<String, String> fieldAttributes = fieldInfo.attributes(); + String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); + + // parametersString will be null when legacy mapper is used + if (parametersString == null) { + parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue())); + + String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION); + Map<String, Object> algoParams = new HashMap<>(); + if (efConstruction != null) { + algoParams.put(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, Integer.parseInt(efConstruction)); + } + + String m = fieldAttributes.get(KNNConstants.HNSW_ALGO_M); + if (m != null) { + algoParams.put(KNNConstants.METHOD_PARAMETER_M, Integer.parseInt(m)); + } + parameters.put(PARAMETERS, algoParams); + } else { + parameters.putAll( + XContentHelper.createParser( + NamedXContentRegistry.EMPTY, + 
DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(parametersString), + MediaTypeRegistry.getDefaultMediaType() + ).map() + ); + } + + parameters.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); + // Update index description of Faiss for binary data type + if (KNNEngine.FAISS == knnEngine + && VectorDataType.BINARY.getValue().equals(vectorDataType.getValue()) + && parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null) { + parameters.put( + KNNConstants.INDEX_DESCRIPTION_PARAMETER, + FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() + ); + } + + // Used to determine how many threads to use when indexing + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + + return parameters; + } + + private Map<String, Object> getTemplateParameters(FieldInfo fieldInfo, Model model) throws IOException { + Map<String, Object> parameters = new HashMap<>(); + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + parameters.put(KNNConstants.MODEL_ID, fieldInfo.attributes().get(MODEL_ID)); + parameters.put(KNNConstants.MODEL_BLOB_PARAMETER, model.getModelBlob()); + IndexUtil.updateVectorDataTypeToParameters(parameters, model.getModelMetadata().getVectorDataType()); + return parameters; + } + + private Model getModel(FieldInfo fieldInfo) { + String modelId = fieldInfo.attributes().get(MODEL_ID); + Model model = ModelCache.getInstance().get(modelId); + if (model.getModelBlob() == null) { + throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); + } + return model; + } + + private void startMergeStats(int numDocs, long arraySize) { + KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); + KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(numDocs); + 
KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize); + KNNGraphValue.MERGE_TOTAL_OPERATIONS.increment(); + KNNGraphValue.MERGE_TOTAL_DOCS.incrementBy(numDocs); + KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.incrementBy(arraySize); + } + + private void endMergeStats(int numDocs, long arraySize) { + KNNGraphValue.MERGE_CURRENT_OPERATIONS.decrement(); + KNNGraphValue.MERGE_CURRENT_DOCS.decrementBy(numDocs); + KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.decrementBy(arraySize); + } + + private void recordRefreshStats() { + KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment(); + } + + private boolean isChecksumValid(long value) { + // NOTE: despite the name, this returns true when any bit selected by CRC32_CHECKSUM_SANITY is set. Check pulled from + // https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L644-L647 + return (value & CRC32_CHECKSUM_SANITY) != 0; + } + + private void writeFooter(String indexPath, String engineFileName, SegmentWriteState state) throws IOException { + // Opens the engine file that was created and appends a footer to it. The footer consists of + // 1. A Footer magic number (int - 4 bytes) + // 2. A checksum algorithm id (int - 4 bytes) + // 3. A checksum (long - 8 bytes) + // The checksum is computed on all the bytes written to the file up to that point.
+ // Logic where footer is written in Lucene can be found here: + // https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L390-L412 + OutputStream os = Files.newOutputStream(Paths.get(indexPath), StandardOpenOption.APPEND); + ByteBuffer byteBuffer = ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN); + byteBuffer.putInt(FOOTER_MAGIC); + byteBuffer.putInt(0); + os.write(byteBuffer.array()); + os.flush(); + + ChecksumIndexInput checksumIndexInput = state.directory.openChecksumInput(engineFileName, state.context); + checksumIndexInput.seek(checksumIndexInput.length()); + long value = checksumIndexInput.getChecksum(); + checksumIndexInput.close(); + + if (isChecksumValid(value)) { + throw new IllegalStateException("Illegal CRC-32 checksum: " + value + " (resource=" + os + ")"); + } + + // Write the CRC checksum to the end of the OutputStream and close the stream + byteBuffer.putLong(0, value); + os.write(byteBuffer.array()); + os.close(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/VectorTransferIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/VectorTransferIndexBuildStrategy.java new file mode 100644 index 000000000..9ea61cffe --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/VectorTransferIndexBuildStrategy.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.codec.transfer.OffHeapByteQuantizedVectorTransfer; +import org.opensearch.knn.index.codec.transfer.OffHeapFloatVectorTransfer; +import org.opensearch.knn.index.codec.transfer.VectorTransfer; +import 
org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.jni.JNIService; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; + +/** + * Transfers all vectors to offheap and then builds an index + */ +final class VectorTransferIndexBuildStrategy implements NativeIndexBuildStrategy { + + private static VectorTransferIndexBuildStrategy INSTANCE = new VectorTransferIndexBuildStrategy(); + + public static VectorTransferIndexBuildStrategy getInstance() { + return INSTANCE; + } + + private VectorTransferIndexBuildStrategy() {} + + public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues<?> knnVectorValues) throws IOException { + // iterating it once to be safe + knnVectorValues.init(); + try (final VectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), knnVectorValues)) { + vectorTransfer.transferBatch(); + assert !vectorTransfer.hasNext(); + + final Map<String, Object> params = indexInfo.getParameters(); + // Currently this is if else as there are only two cases, with more cases this will have to be made + // more maintainable + if (params.containsKey(MODEL_ID)) { + AccessController.doPrivileged((PrivilegedAction<Void>) () -> { + JNIService.createIndexFromTemplate( + vectorTransfer.getTransferredDocsIds(), + vectorTransfer.getVectorAddress(), + knnVectorValues.dimension(), + indexInfo.getIndexPath(), + (byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER), + indexInfo.getParameters(), + indexInfo.getKnnEngine() + ); + return null; + }); + } else { + AccessController.doPrivileged((PrivilegedAction<Void>) () -> { + JNIService.createIndex( + vectorTransfer.getTransferredDocsIds(), + vectorTransfer.getVectorAddress(), + knnVectorValues.dimension(), + indexInfo.getIndexPath(), + 
indexInfo.getParameters(), + indexInfo.getKnnEngine() + ); + return null; + }); + } + + } catch (Exception exception) { + throw new RuntimeException("Failed to build index", exception); + } + } + + private VectorTransfer getVectorTransfer(VectorDataType vectorDataType, KNNVectorValues<?> knnVectorValues) throws IOException { + switch (vectorDataType) { + case FLOAT: + return new OffHeapFloatVectorTransfer((KNNFloatVectorValues) knnVectorValues, knnVectorValues.totalLiveDocs()); + case BINARY: + case BYTE: + return new OffHeapByteQuantizedVectorTransfer<>((KNNVectorValues<byte[]>) knnVectorValues, knnVectorValues.totalLiveDocs()); + default: + throw new IllegalArgumentException("Unsupported vector data type: " + vectorDataType); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java new file mode 100644 index 000000000..d1a0645ca --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex.model; + +import lombok.Builder; +import lombok.Value; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNEngine; + +import java.util.Map; + +@Value +@Builder +public class BuildIndexParams { + KNNEngine knnEngine; + String indexPath; + VectorDataType vectorDataType; + Map<String, Object> parameters; + // TODO: Add quantization state as parameter to build index +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryQuantizedVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryQuantizedVectorTransfer.java new file mode 100644 index 000000000..ff346a810 --- /dev/null +++ 
b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryQuantizedVectorTransfer.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import org.apache.commons.lang.StringUtils; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; + +import java.io.IOException; +import java.util.List; + +/** + * Transfer quantized binary vectors to off heap memory + * The reason this is different from {@link OffHeapByteQuantizedVectorTransfer} is because of allocation and deallocation + * of memory on JNI layer. Use this if unsigned int is needed on JNI layer + */ +public final class OffHeapBinaryQuantizedVectorTransfer<T> extends OffHeapQuantizedVectorTransfer<T, byte[]> { + + public OffHeapBinaryQuantizedVectorTransfer(KNNVectorValues<T> vectorValues, Long batchSize) { + super(vectorValues, batchSize, (vector, state) -> (byte[]) vector, StringUtils.EMPTY, DEFAULT_COMPRESSION_FACTOR); + } + + public OffHeapBinaryQuantizedVectorTransfer(KNNVectorValues<T> vectorValues) { + this(vectorValues, null); + } + + @Override + public void close() { + super.close(); + // TODO: deallocate the memory location + } + + @Override + protected long transfer(List<byte[]> vectorsToTransfer, boolean append) throws IOException { + // TODO: call to JNIService to transfer vector + return 0L; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteQuantizedVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteQuantizedVectorTransfer.java new file mode 100644 index 000000000..e5c8d3e12 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteQuantizedVectorTransfer.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import org.apache.commons.lang.StringUtils; +import 
org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.jni.JNICommons; + +import java.io.IOException; +import java.util.List; + +/** + * Transfer quantized byte vectors to off heap memory. + * The reason this is different from {@link OffHeapBinaryQuantizedVectorTransfer} is because of allocation and deallocation + * of memory on JNI layer. Use this if signed int is needed on JNI layer + */ +public final class OffHeapByteQuantizedVectorTransfer<T> extends OffHeapQuantizedVectorTransfer<T, byte[]> { + + public OffHeapByteQuantizedVectorTransfer(KNNVectorValues<T> vectorValues, final Long batchSize) throws IOException { + super(vectorValues, batchSize, (vector, state) -> (byte[]) vector, StringUtils.EMPTY, DEFAULT_COMPRESSION_FACTOR); + } + + public OffHeapByteQuantizedVectorTransfer(KNNVectorValues<T> vectorValues) throws IOException { + this(vectorValues, null); + } + + @Override + protected long transfer(List<byte[]> batch, boolean append) throws IOException { + return JNICommons.storeByteVectorData(getVectorAddress(), batch.toArray(new byte[][] {}), batchSize * batch.get(0).length, append); + } + + @Override + public void close() { + JNICommons.freeByteVectorData(getVectorAddress()); // free first: super.close() resets the address to 0 + super.close(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java new file mode 100644 index 000000000..66246494e --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import org.apache.commons.lang.StringUtils; +import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; +import org.opensearch.knn.jni.JNICommons; + +import java.io.IOException; +import java.util.List; + +/** + * Transfer float vectors
to off heap memory. + */ +public final class OffHeapFloatVectorTransfer extends OffHeapQuantizedVectorTransfer<float[], float[]> { + + public OffHeapFloatVectorTransfer(KNNFloatVectorValues vectorValues, Long batchSize) throws IOException { + super(vectorValues, batchSize, (vector, state) -> vector, StringUtils.EMPTY, DEFAULT_COMPRESSION_FACTOR); + } + + public OffHeapFloatVectorTransfer(KNNFloatVectorValues vectorValues) throws IOException { + this(vectorValues, null); + } + + @Override + protected long transfer(final List<float[]> vectorsToTransfer, boolean append) throws IOException { + return JNICommons.storeVectorData( + getVectorAddress(), + vectorsToTransfer.toArray(new float[][] {}), + this.batchSize * vectorsToTransfer.get(0).length, + append + ); + } + + @Override + public void close() { + JNICommons.freeVectorData(getVectorAddress()); // free first: super.close() resets the address to 0 + super.close(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapQuantizedVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapQuantizedVectorTransfer.java new file mode 100644 index 000000000..7ddd93620 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapQuantizedVectorTransfer.java @@ -0,0 +1,124 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import lombok.Getter; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; + +import java.io.IOException; +import java.util.List; +import java.util.function.BiFunction; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.opensearch.knn.common.KNNVectorUtil.createArrayList; + +/** + * The class is intended to transfer {@link KNNVectorValues} to off heap memory. It also provides an ability to quantize the vector + * before it is transferred to offHeap memory.
The class is not thread safe + * + * @param <T> an array of primitive type + * @param <V> an array of primitive type after being quantized + */ +abstract class OffHeapQuantizedVectorTransfer<T, V> implements VectorTransfer { + + protected static final int DEFAULT_COMPRESSION_FACTOR = 1; + + @Getter + private long vectorAddress; + @Getter + private int[] transferredDocsIds; + private final int transferLimit; + // Keeping this as a member variable as this should not be changed considering the vector address is reused between batches + protected long batchSize; + + private final List<V> vectorsToTransfer; + private final List<Integer> transferredDocIdsList; + + private final KNNVectorValues<T> vectorValues; + + // TODO: Replace with actual quantization parameters + private final BiFunction<T, String, V> quantizer; + private final String quantizationState; + + public OffHeapQuantizedVectorTransfer( + final KNNVectorValues<T> vectorValues, + final Long batchSize, + final BiFunction<T, String, V> quantizer, + final String quantizationState, + final int compressionFactor + ) { + assert vectorValues.docId() != -1 : "vectorValues docId must be set, iterate it once for vector transfer to succeed"; + assert vectorValues.docId() != NO_MORE_DOCS : "vectorValues already iterated, Nothing to transfer"; + + this.quantizer = quantizer; + this.quantizationState = quantizationState; + this.transferLimit = (int) Math.min( + Integer.MAX_VALUE, + Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / (vectorValues.bytesPerVector() / compressionFactor)) + ); // divide in long and clamp: casting the byte limit to int before dividing truncates limits of 2 GiB or more + this.batchSize = batchSize == null ?
transferLimit : batchSize; + this.vectorsToTransfer = createArrayList(this.batchSize); + this.transferredDocIdsList = createArrayList(this.batchSize); + this.vectorValues = vectorValues; + this.vectorAddress = 0; // we can allocate initial memory here, currently storeVectorData takes care of it + } + + @Override + public void transferBatch() throws IOException { + if (vectorValues.docId() == NO_MORE_DOCS) { + // Throwing instead of returning so there is no way client can go into an infinite loop + throw new IllegalStateException("No more vectors available to transfer"); + } + + assert vectorsToTransfer.isEmpty() : "Last batch wasn't transferred"; + assert transferredDocIdsList.isEmpty() : "Last batch wasn't transferred"; + + int totalDocsTransferred = 0; + boolean freshBatch = true; + + // TODO: Create non-final QuantizationOutput once here and then reuse the output + while (vectorValues.docId() != NO_MORE_DOCS && totalDocsTransferred < batchSize) { + V quantizedVector = quantizer.apply(vectorValues.conditionalCloneVector(), quantizationState); + + transferredDocIdsList.add(vectorValues.docId()); + vectorsToTransfer.add(quantizedVector); + if (vectorsToTransfer.size() == transferLimit) { + vectorAddress = transfer(vectorsToTransfer, !freshBatch); + vectorsToTransfer.clear(); + freshBatch = false; + } + vectorValues.nextDoc(); + totalDocsTransferred++; + } + + // Handle batchSize < transferLimit + if (!vectorsToTransfer.isEmpty()) { + vectorAddress = transfer(vectorsToTransfer, !freshBatch); + vectorsToTransfer.clear(); + } + + this.transferredDocsIds = new int[transferredDocIdsList.size()]; + for (int i = 0; i < transferredDocIdsList.size(); i++) { + transferredDocsIds[i] = transferredDocIdsList.get(i); + } + transferredDocIdsList.clear(); + } + + @Override + public boolean hasNext() { + return vectorValues.docId() != NO_MORE_DOCS; + } + + @Override + public void close() { + transferredDocIdsList.clear(); + transferredDocsIds = null; + vectorAddress = 0; + } + + 
protected abstract long transfer(final List<V> vectorsToTransfer, final boolean append) throws IOException; +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java index c23bd4317..fd76a7861 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java @@ -5,50 +5,37 @@ package org.opensearch.knn.index.codec.transfer; -import lombok.Data; -import org.apache.lucene.util.BytesRef; -import org.opensearch.knn.index.codec.util.SerializationMode; +import java.io.Closeable; +import java.io.IOException; /** - * Abstract class to transfer vector value from Java to native memory + * An interface to transfer vectors from one memory location to another + * Class is Closeable to be able to release memory once done */ -@Data -public abstract class VectorTransfer { - protected final long vectorsStreamingMemoryLimit; - protected long totalLiveDocs; - protected long vectorsPerTransfer; - protected long vectorAddress; - protected int dimension; - - public VectorTransfer(final long vectorsStreamingMemoryLimit) { - this.vectorsStreamingMemoryLimit = vectorsStreamingMemoryLimit; - this.vectorsPerTransfer = Integer.MIN_VALUE; - } +public interface VectorTransfer extends Closeable { /** - * Initialize the transfer - * - * @param totalLiveDocs total number of vectors to be transferred + * Transfer a batch of vectors from one location to another + * The batch size here is intended to be constant for multiple transfers so should be encapsulated in the + * implementation. 
A new batch size should require another instance + * @throws IOException */ - abstract public void init(final long totalLiveDocs); + void transferBatch() throws IOException; /** - * Transfer a single vector - * - * @param bytesRef a vector in bytes format + * Indicates if there are more vectors to transfer + * @return */ - abstract public void transfer(final BytesRef bytesRef); + boolean hasNext(); /** - * Close the transfer + * Gives the docIds for transfered vectors + * @return */ - abstract public void close(); + int[] getTransferredDocsIds(); /** - * Get serialization mode of given byte stream - * - * @param bytesRef bytes of a vector - * @return serialization mode + * @return the memory address of the vectors transferred */ - abstract public SerializationMode getSerializationMode(final BytesRef bytesRef); + long getVectorAddress(); } diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java deleted file mode 100644 index e81ac35fc..000000000 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.transfer; - -import org.apache.lucene.util.ArrayUtil; -import org.apache.lucene.util.BytesRef; -import org.opensearch.knn.index.codec.util.SerializationMode; -import org.opensearch.knn.jni.JNICommons; - -import java.util.ArrayList; -import java.util.List; - -/** - * Vector transfer for byte - */ -public class VectorTransferByte extends VectorTransfer { - private List<byte[]> vectorList; - - public VectorTransferByte(final long vectorsStreamingMemoryLimit) { - super(vectorsStreamingMemoryLimit); - vectorList = new ArrayList<>(); - } - - @Override - public void init(final long totalLiveDocs) { - this.totalLiveDocs = totalLiveDocs; - vectorList.clear(); - } - - @Override - 
public void transfer(final BytesRef bytesRef) { - dimension = bytesRef.length * 8; - if (vectorsPerTransfer == Integer.MIN_VALUE) { - // if vectorsStreamingMemoryLimit is 100 bytes and we have 50 vectors with length of 5, then per - // transfer we have to send 100/5 => 20 vectors. - vectorsPerTransfer = vectorsStreamingMemoryLimit / bytesRef.length; - // If vectorsPerTransfer comes out to be 0, then we set number of vectors per transfer to 1, to ensure that - // we are sending minimum number of vectors. - if (vectorsPerTransfer == 0) { - vectorsPerTransfer = 1; - } - } - - vectorList.add(ArrayUtil.copyOfSubArray(bytesRef.bytes, bytesRef.offset, bytesRef.offset + bytesRef.length)); - if (vectorList.size() == vectorsPerTransfer) { - transfer(); - } - } - - @Override - public void close() { - transfer(); - } - - @Override - public SerializationMode getSerializationMode(final BytesRef bytesRef) { - return SerializationMode.COLLECTIONS_OF_BYTES; - } - - private void transfer() { - int lengthOfVector = dimension / 8; - vectorAddress = JNICommons.storeByteVectorData(vectorAddress, vectorList.toArray(new byte[][] {}), totalLiveDocs * lengthOfVector); - vectorList.clear(); - } -} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java deleted file mode 100644 index a9c792398..000000000 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.transfer; - -import org.apache.lucene.util.BytesRef; -import org.opensearch.knn.index.codec.util.KNNVectorSerializer; -import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; -import org.opensearch.knn.index.codec.util.SerializationMode; -import org.opensearch.knn.jni.JNICommons; - -import java.util.ArrayList; -import 
java.util.List; - -/** - * Vector transfer for float - */ -public class VectorTransferFloat extends VectorTransfer { - private List<float[]> vectorList; - - public VectorTransferFloat(final long vectorsStreamingMemoryLimit) { - super(vectorsStreamingMemoryLimit); - vectorList = new ArrayList<>(); - } - - @Override - public void init(final long totalLiveDocs) { - this.totalLiveDocs = totalLiveDocs; - vectorList.clear(); - } - - @Override - public void transfer(final BytesRef bytesRef) { - final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByBytesRef(bytesRef); - final float[] vector = vectorSerializer.byteToFloatArray(bytesRef); - dimension = vector.length; - - if (vectorsPerTransfer == Integer.MIN_VALUE) { - // if vectorsStreamingMemoryLimit is 100 bytes and we have 50 vectors with 5 dimension, then per - // transfer we have to send 100/(5 * 4) => 5 vectors. - vectorsPerTransfer = vectorsStreamingMemoryLimit / ((long) dimension * Float.BYTES); - // If vectorsPerTransfer comes out to be 0, then we set number of vectors per transfer to 1, to ensure that - // we are sending minimum number of vectors. 
- if (vectorsPerTransfer == 0) { - vectorsPerTransfer = 1; - } - } - - vectorList.add(vector); - if (vectorList.size() == vectorsPerTransfer) { - transfer(); - } - } - - @Override - public void close() { - transfer(); - } - - @Override - public SerializationMode getSerializationMode(final BytesRef bytesRef) { - return KNNVectorSerializerFactory.getSerializerModeFromBytesRef(bytesRef); - } - - private void transfer() { - vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension); - vectorList.clear(); - } -} diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index ea14fe883..51100a1e0 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -5,63 +5,14 @@ package org.opensearch.knn.index.codec.util; -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.Setter; import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.util.BytesRef; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.KNN80Codec.KNN80BinaryDocValues; -import org.opensearch.knn.index.codec.transfer.VectorTransfer; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; public class KNNCodecUtil { // Floats are 4 bytes in size public static final int FLOAT_BYTE_SIZE = 4; - @AllArgsConstructor - public static final class Pair { - public int[] docs; - @Getter - @Setter - private long vectorAddress; - @Getter - @Setter - private int dimension; - public SerializationMode serializationMode; - } - - /** - * Extract docIds and vectors from binary doc values. 
- * - * @param values Binary doc values - * @param vectorTransfer Utility to make transfer - * @return KNNCodecUtil.Pair representing doc ids and corresponding vectors - * @throws IOException thrown when unable to get binary of vectors - */ - public static KNNCodecUtil.Pair getPair(final BinaryDocValues values, final VectorTransfer vectorTransfer) throws IOException { - List<Integer> docIdList = new ArrayList<>(); - SerializationMode serializationMode = SerializationMode.COLLECTION_OF_FLOATS; - vectorTransfer.init(getTotalLiveDocsCount(values)); - for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) { - BytesRef bytesref = values.binaryValue(); - serializationMode = vectorTransfer.getSerializationMode(bytesref); - vectorTransfer.transfer(bytesref); - docIdList.add(doc); - } - vectorTransfer.close(); - return new KNNCodecUtil.Pair( - docIdList.stream().mapToInt(Integer::intValue).toArray(), - vectorTransfer.getVectorAddress(), - vectorTransfer.getDimension(), - serializationMode - ); - } - /** * This method provides a rough estimate of the number of bytes used for storing an array with the given parameters. 
* @param numVectors number of vectors in the array diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java index f38099b74..5da093fd5 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java @@ -11,6 +11,7 @@ import org.apache.lucene.index.ByteVectorValues; import java.io.IOException; +import java.util.Arrays; /** * Concrete implementation of {@link KNNVectorValues} that returns byte[] as vector where binary vector is stored and @@ -25,17 +26,17 @@ public class KNNBinaryVectorValues extends KNNVectorValues<byte[]> { @Override public byte[] getVector() throws IOException { final byte[] vector = VectorValueExtractorStrategy.extractBinaryVector(vectorValuesIterator); - this.dimension = vector.length; + this.dimension = vector.length * Byte.SIZE; + this.bytesPerVector = vector.length; return vector; } - /** - * Binary Vector values gets stored as byte[], hence for dimension of the binary vector we have to multiply the - * byte[] size with {@link Byte#SIZE} - * @return int - */ @Override - public int dimension() { - return super.dimension() * Byte.SIZE; + public byte[] conditionalCloneVector() throws IOException { + byte[] vector = getVector(); + if (vectorValuesIterator.getDocIdSetIterator() instanceof ByteVectorValues) { + return Arrays.copyOf(vector, vector.length); + } + return vector; } } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java index ccbbfab77..1ebc50970 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java @@ -11,6 +11,7 @@ import org.apache.lucene.index.ByteVectorValues; import 
java.io.IOException; +import java.util.Arrays; /** * Concrete implementation of {@link KNNVectorValues} that returns float[] as vector and provides an abstraction over @@ -26,6 +27,17 @@ public class KNNByteVectorValues extends KNNVectorValues<byte[]> { public byte[] getVector() throws IOException { final byte[] vector = VectorValueExtractorStrategy.extractByteVector(vectorValuesIterator); this.dimension = vector.length; + this.bytesPerVector = vector.length; + return vector; + } + + @Override + public byte[] conditionalCloneVector() throws IOException { + byte[] vector = getVector(); + if (vectorValuesIterator.getDocIdSetIterator() instanceof ByteVectorValues) { + return Arrays.copyOf(vector, vector.length); + + } return vector; } } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java index 174f3a89e..d11739ee6 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java @@ -10,6 +10,7 @@ import org.apache.lucene.index.FloatVectorValues; import java.io.IOException; +import java.util.Arrays; /** * Concrete implementation of {@link KNNVectorValues} that returns float[] as vector and provides an abstraction over @@ -24,6 +25,17 @@ public class KNNFloatVectorValues extends KNNVectorValues<float[]> { public float[] getVector() throws IOException { final float[] vector = VectorValueExtractorStrategy.extractFloatVector(vectorValuesIterator); this.dimension = vector.length; + this.bytesPerVector = vector.length * 4L; + return vector; + } + + @Override + public float[] conditionalCloneVector() throws IOException { + float[] vector = getVector(); + if (vectorValuesIterator.getDocIdSetIterator() instanceof FloatVectorValues) { + return Arrays.copyOf(vector, vector.length); + } + return vector; } } diff --git 
a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java index c4ed64bc2..c2da8cde1 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java @@ -23,11 +23,24 @@ public abstract class KNNVectorValues<T> { protected final KNNVectorValuesIterator vectorValuesIterator; protected int dimension; + protected long bytesPerVector; protected KNNVectorValues(final KNNVectorValuesIterator vectorValuesIterator) { this.vectorValuesIterator = vectorValuesIterator; } + /** + * Advances to the first document and reads its vector, but only if the iterator has not been advanced yet (docId() == -1) + * Also populates dimension and bytesPerVector in the process + * @throws IOException + */ + public void init() throws IOException { + if (docId() == -1) { + nextDoc(); + getVector(); + } + } + /** * Return a vector reference. If you are adding this address in a List/Map ensure that you are copying the vector first. * This is to ensure that we keep the heap and latency in check by reducing the copies of vectors. @@ -37,6 +50,19 @@ protected KNNVectorValues(final KNNVectorValuesIterator vectorValuesIterator) { */ public abstract T getVector() throws IOException; + /** + * Intended to return a vector reference either after deep copy of the vector obtained from {@code getVector} + * or return the vector itself + * + * This decision to clone depends on the vector returned based on the type of iterator + * + * @return T an array of byte[], float[] Or a deep copy of it + * @throws IOException + */ + public T conditionalCloneVector() throws IOException { + return getVector(); + } + /** * Dimension of vector is returned. Do call getVector function first before calling this function otherwise you will get 0 value. * @return int @@ -46,6 +72,15 @@ public int dimension() { return dimension; } + /** + * Size of a vector in bytes is returned.
Do call getVector function first before calling this function otherwise you will get 0 value. + * @return long + */ + public long bytesPerVector() { + assert docId() != -1 && bytesPerVector != 0 : "Cannot get bytesPerVector before we retrieve a vector from KNNVectorValues"; + return bytesPerVector; + } + /** * Returns the total live docs for KNNVectorValues. * @return long @@ -81,5 +116,4 @@ public int advance(int docId) throws IOException { public int nextDoc() throws IOException { return vectorValuesIterator.nextDoc(); } - } diff --git a/src/main/java/org/opensearch/knn/indices/ModelUtil.java b/src/main/java/org/opensearch/knn/indices/ModelUtil.java index 0f5a049fc..29eb307e0 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelUtil.java +++ b/src/main/java/org/opensearch/knn/indices/ModelUtil.java @@ -15,9 +15,14 @@ import java.util.Locale; +import lombok.experimental.UtilityClass; + +import java.util.Locale; + /** * A utility class for models. */ +@UtilityClass public class ModelUtil { public static void blockCommasInModelDescription(String description) { @@ -48,7 +53,7 @@ public static ModelMetadata getModelMetadata(final String modelId) { } final Model model = ModelCache.getInstance().get(modelId); final ModelMetadata modelMetadata = model.getModelMetadata(); - if (ModelUtil.isModelCreated(modelMetadata) == false) { + if (!ModelUtil.isModelCreated(modelMetadata)) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Model ID '%s' is not created.", modelId)); } return modelMetadata; diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 4f57b616a..a402be1f3 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -50,32 +50,70 @@ class FaissService { } /** - * Create an index for the native library The memory occupied by the vectorsAddress will be freed up during the + * Initialize an index for
the native library. Takes in numDocs to + * allocate the correct amount of memory. + * + * @param numDocs number of documents to be added + * @param dim dimension of the vector to be indexed + * @param parameters parameters to build index + */ + public static native long initIndex(long numDocs, int dim, Map<String, Object> parameters); + + /** + * Initialize an index for the native library. Takes in numDocs to + * allocate the correct amount of memory. + * + * @param numDocs number of documents to be added + * @param dim dimension of the vector to be indexed + * @param parameters parameters to build index + */ + public static native long initBinaryIndex(long numDocs, int dim, Map<String, Object> parameters); + + /** + * Inserts to a faiss index. The memory occupied by the vectorsAddress will be freed up during the * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer - * created the memory address and that should only free up the memory. We are tracking the proper fix for this on this - * <a href="https://github.com/opensearch-project/k-NN/issues/1600">issue</a> + * created the memory address and that should only free up the memory. 
* - * @param ids array of ids mapping to the data passed in + * @param ids ids of documents * @param vectorsAddress address of native memory where vectors are stored * @param dim dimension of the vector to be indexed - * @param indexPath path to save index file to - * @param parameters parameters to build index + * @param indexAddress address of native memory where index is stored + * @param threadCount number of threads to use for insertion */ - public static native void createIndex(int[] ids, long vectorsAddress, int dim, String indexPath, Map<String, Object> parameters); + public static native void insertToIndex(int[] ids, long vectorsAddress, int dim, long indexAddress, int threadCount); /** - * Create a binary index for the native library The memory occupied by the vectorsAddress will be freed up during the + * Inserts to a faiss index. The memory occupied by the vectorsAddress will be freed up during the * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer - * created the memory address and that should only free up the memory. We are tracking the proper fix for this on this - * <a href="https://github.com/opensearch-project/k-NN/issues/1600">issue</a> + * created the memory address and that should only free up the memory. * - * @param ids array of ids mapping to the data passed in + * @param ids ids of documents * @param vectorsAddress address of native memory where vectors are stored * @param dim dimension of the vector to be indexed + * @param indexAddress address of native memory where index is stored + * @param threadCount number of threads to use for insertion + */ + public static native void insertToBinaryIndex(int[] ids, long vectorsAddress, int dim, long indexAddress, int threadCount); + + /** + * Writes a faiss index. + * + * NOTE: This will always free the index. Do not call free after this. 
+ * + * @param indexAddress address of native memory where index is stored + * @param indexPath path to save index file to + */ + public static native void writeIndex(long indexAddress, String indexPath); + + /** + * Writes a faiss index. + * + * NOTE: This will always free the index. Do not call free after this. + * + * @param indexAddress address of native memory where index is stored * @param indexPath path to save index file to - * @param parameters parameters to build index */ - public static native void createBinaryIndex(int[] ids, long vectorsAddress, int dim, String indexPath, Map<String, Object> parameters); + public static native void writeBinaryIndex(long indexAddress, String indexPath); /** * Create an index for the native library with a provided template index diff --git a/src/main/java/org/opensearch/knn/jni/JNICommons.java b/src/main/java/org/opensearch/knn/jni/JNICommons.java index 31a8f43cc..c7222738e 100644 --- a/src/main/java/org/opensearch/knn/jni/JNICommons.java +++ b/src/main/java/org/opensearch/knn/jni/JNICommons.java @@ -36,16 +36,59 @@ public class JNICommons { * will throw Exception. * * <p> - * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can - * lead to data corruption. + * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can + * lead to data corruption. * </p> * - * @param memoryAddress The address of the memory location where data will be stored. - * @param data 2D float array containing data to be stored in native memory. + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D float array containing data to be stored in native memory. * @param initialCapacity The initial capacity of the memory location. * @return memory address where the data is stored. 
*/ - public static native long storeVectorData(long memoryAddress, float[][] data, long initialCapacity); + public static long storeVectorData(long memoryAddress, float[][] data, long initialCapacity) { + return storeVectorData(memoryAddress, data, initialCapacity, true); + } + + /** + * This is utility function that can be used to store data in native memory. This function will allocate memory for + * the data(rows*columns) with initialCapacity and return the memory address where the data is stored. + * If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created. + * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location + * will throw Exception. + * + * <p> + * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can + * lead to data corruption. + * </p> + * + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D float array containing data to be stored in native memory. + * @param initialCapacity The initial capacity of the memory location. + * @param append append the data or rewrite the memory location + * @return memory address where the data is stored. + */ + public static native long storeVectorData(long memoryAddress, float[][] data, long initialCapacity, boolean append); + + /** + * This is utility function that can be used to store data in native memory. This function will allocate memory for + * the data(rows*columns) with initialCapacity and return the memory address where the data is stored. + * If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created. + * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location + * will throw Exception. + * + * <p> + * The function is not threadsafe. 
If multiple threads are trying to insert on same memory location, then it can + * lead to data corruption. + * </p> + * + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D byte array containing data to be stored in native memory. + * @param initialCapacity The initial capacity of the memory location. + * @return memory address where the data is stored. + */ + public static long storeByteVectorData(long memoryAddress, byte[][] data, long initialCapacity) { + return storeByteVectorData(memoryAddress, data, initialCapacity, true); + } /** * This is utility function that can be used to store data in native memory. This function will allocate memory for @@ -55,24 +98,25 @@ public class JNICommons { * will throw Exception. * * <p> - * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can - * lead to data corruption. + * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can + * lead to data corruption. * </p> * - * @param memoryAddress The address of the memory location where data will be stored. - * @param data 2D byte array containing data to be stored in native memory. + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D byte array containing data to be stored in native memory. * @param initialCapacity The initial capacity of the memory location. + * @param append append the data or rewrite the memory location * @return memory address where the data is stored. */ - public static native long storeByteVectorData(long memoryAddress, byte[][] data, long initialCapacity); + public static native long storeByteVectorData(long memoryAddress, byte[][] data, long initialCapacity, boolean append); /** * Free up the memory allocated for the data stored in memory address. 
This function should be used with the memory + * address returned by {@link JNICommons#storeVectorData(long, float[][], long, boolean)} * * <p> - * The function is not threadsafe. If multiple threads are trying to free up same memory location, then it can - * lead to errors. + * The function is not threadsafe. If multiple threads are trying to free up same memory location, then it can + * lead to errors. * </p> * * @param memoryAddress address to be freed. @@ -81,11 +125,11 @@ public class JNICommons { /** * Free up the memory allocated for the byte data stored in memory address. This function should be used with the memory - * address returned by {@link JNICommons#storeVectorData(long, float[][], long)} + * address returned by {@link JNICommons#storeByteVectorData(long, byte[][], long, boolean)} * * <p> - * The function is not threadsafe. If multiple threads are trying to free up same memory location, then it can - * lead to errors. + * The function is not threadsafe. If multiple threads are trying to free up same memory location, then it can + * lead to errors. * </p> * * @param memoryAddress address to be freed.
diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index de696b5ce..7427af674 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -13,28 +13,110 @@ import org.apache.commons.lang.ArrayUtils; import org.opensearch.common.Nullable; -import org.opensearch.knn.index.util.IndexUtil; -import org.opensearch.knn.index.query.KNNQueryResult; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.query.KNNQueryResult; +import org.opensearch.knn.index.util.IndexUtil; +import java.util.Locale; import java.util.Map; /** * Service to distribute requests to the proper engine jni service */ public class JNIService { + /** + * Initialize an index for the native library. Takes in numDocs to + * allocate the correct amount of memory. + * + * @param numDocs number of documents to be added + * @param dim dimension of the vector to be indexed + * @param parameters parameters to build index + * @param knnEngine knn engine + * @return address of the index in memory + */ + public static long initIndexFromScratch(long numDocs, int dim, Map<String, Object> parameters, KNNEngine knnEngine) { + if (KNNEngine.FAISS == knnEngine) { + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { + return FaissService.initBinaryIndex(numDocs, dim, parameters); + } else { + return FaissService.initIndex(numDocs, dim, parameters); + } + } + + throw new IllegalArgumentException( + String.format(Locale.ROOT, "initIndexFromScratch not supported for provided engine : %s", knnEngine.getName()) + ); + } + + /** + * Inserts to a faiss index. 
+ * + * @param docs ids of documents + * @param vectorsAddress address of native memory where vectors are stored + * @param dimension dimension of the vector to be indexed + * @param parameters parameters to build index + * @param indexAddress address of native memory where index is stored + * @param knnEngine knn engine + */ + public static void insertToIndex( + int[] docs, + long vectorsAddress, + int dimension, + Map<String, Object> parameters, + long indexAddress, + KNNEngine knnEngine + ) { + int threadCount = (int) parameters.getOrDefault(KNNConstants.INDEX_THREAD_QTY, 0); + if (KNNEngine.FAISS == knnEngine) { + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { + FaissService.insertToBinaryIndex(docs, vectorsAddress, dimension, indexAddress, threadCount); + } else { + FaissService.insertToIndex(docs, vectorsAddress, dimension, indexAddress, threadCount); + } + return; + } + + throw new IllegalArgumentException( + String.format(Locale.ROOT, "insertToIndex not supported for provided engine : %s", knnEngine.getName()) + ); + } + + /** + * Writes a faiss index to disk. + * + * @param indexPath path to save index to + * @param indexAddress address of native memory where index is stored + * @param knnEngine knn engine + * @param parameters parameters to build index + */ + public static void writeIndex(String indexPath, long indexAddress, KNNEngine knnEngine, Map<String, Object> parameters) { + if (KNNEngine.FAISS == knnEngine) { + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { + FaissService.writeBinaryIndex(indexAddress, indexPath); + } else { + FaissService.writeIndex(indexAddress, indexPath); + } + return; + } + + throw new IllegalArgumentException( + String.format(Locale.ROOT, "writeIndex not supported for provided engine : %s", knnEngine.getName()) + ); + } + /** * Create an index for the native library. The memory occupied by the vectorsAddress will be freed up during the * function call. So Java layer doesn't need to free up the memory. 
This is not an ideal behavior because Java layer * created the memory address and that should only free up the memory. We are tracking the proper fix for this on this * <a href="https://github.com/opensearch-project/k-NN/issues/1600">issue</a> * - * @param ids array of ids mapping to the data passed in + * @param ids array of ids mapping to the data passed in * @param vectorsAddress address of native memory where vectors are stored - * @param dim dimension of the vector to be indexed - * @param indexPath path to save index file to - * @param parameters parameters to build index - * @param knnEngine engine to build index for + * @param dim dimension of the vector to be indexed + * @param indexPath path to save index file to + * @param parameters parameters to build index + * @param knnEngine engine to build index for */ public static void createIndex( int[] ids, @@ -50,28 +132,21 @@ public static void createIndex( return; } - if (KNNEngine.FAISS == knnEngine) { - if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { - FaissService.createBinaryIndex(ids, vectorsAddress, dim, indexPath, parameters); - } else { - FaissService.createIndex(ids, vectorsAddress, dim, indexPath, parameters); - } - return; - } - - throw new IllegalArgumentException(String.format("CreateIndex not supported for provided engine : %s", knnEngine.getName())); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "CreateIndex not supported for provided engine : %s", knnEngine.getName()) + ); } /** * Create an index for the native library with a provided template index * - * @param ids array of ids mapping to the data passed in + * @param ids array of ids mapping to the data passed in * @param vectorsAddress address of native memory where vectors are stored - * @param dim dimension of vectors to be indexed - * @param indexPath path to save index file to - * @param templateIndex empty template index - * @param parameters parameters to build index - * @param knnEngine engine to build 
index for + * @param dim dimension of vectors to be indexed + * @param indexPath path to save index file to + * @param templateIndex empty template index + * @param parameters parameters to build index + * @param knnEngine engine to build index for */ public static void createIndexFromTemplate( int[] ids, @@ -93,7 +168,7 @@ public static void createIndexFromTemplate( } throw new IllegalArgumentException( - String.format("CreateIndexFromTemplate not supported for provided engine : %s", knnEngine.getName()) + String.format(Locale.ROOT, "CreateIndexFromTemplate not supported for provided engine : %s", knnEngine.getName()) ); } @@ -118,7 +193,9 @@ public static long loadIndex(String indexPath, Map<String, Object> parameters, K } } - throw new IllegalArgumentException(String.format("LoadIndex not supported for provided engine : %s", knnEngine.getName())); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "LoadIndex not supported for provided engine : %s", knnEngine.getName()) + ); } /** @@ -150,7 +227,7 @@ public static long initSharedIndexState(long indexAddr, KNNEngine knnEngine) { return FaissService.initSharedIndexState(indexAddr); } throw new IllegalArgumentException( - String.format("InitSharedIndexState not supported for provided engine : %s", knnEngine.getName()) + String.format(Locale.ROOT, "InitSharedIndexState not supported for provided engine : %s", knnEngine.getName()) ); } @@ -168,20 +245,20 @@ public static void setSharedIndexState(long indexAddr, long shareIndexStateAddr, } throw new IllegalArgumentException( - String.format("SetSharedIndexState not supported for provided engine : %s", knnEngine.getName()) + String.format(Locale.ROOT, "SetSharedIndexState not supported for provided engine : %s", knnEngine.getName()) ); } /** * Query an index * - * @param indexPointer pointer to index in memory - * @param queryVector vector to be used for query - * @param k neighbors to be returned - * @param methodParameters method parameter - * @param 
knnEngine engine to query index - * @param filteredIds array of ints on which should be used for search. - * @param filterIdsType how to filter ids: Batch or BitMap + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param k neighbors to be returned + * @param methodParameters method parameter + * @param knnEngine engine to query index + * @param filteredIds array of ints on which should be used for search. + * @param filterIdsType how to filter ids: Batch or BitMap * @return KNNQueryResult array of k neighbors */ public static KNNQueryResult[] queryIndex( @@ -216,19 +293,21 @@ public static KNNQueryResult[] queryIndex( } return FaissService.queryIndex(indexPointer, queryVector, k, methodParameters, parentIds); } - throw new IllegalArgumentException(String.format("QueryIndex not supported for provided engine : %s", knnEngine.getName())); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "QueryIndex not supported for provided engine : %s", knnEngine.getName()) + ); } /** * Query a binary index * - * @param indexPointer pointer to index in memory - * @param queryVector vector to be used for query - * @param k neighbors to be returned - * @param methodParameters method parameter - * @param knnEngine engine to query index - * @param filteredIds array of ints on which should be used for search. - * @param filterIdsType how to filter ids: Batch or BitMap + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param k neighbors to be returned + * @param methodParameters method parameter + * @param knnEngine engine to query index + * @param filteredIds array of ints on which should be used for search. 
+ * @param filterIdsType how to filter ids: Batch or BitMap * @return KNNQueryResult array of k neighbors */ public static KNNQueryResult[] queryBinaryIndex( @@ -252,7 +331,9 @@ public static KNNQueryResult[] queryBinaryIndex( parentIds ); } - throw new IllegalArgumentException(String.format("QueryBinaryIndex not supported for provided engine : %s", knnEngine.getName())); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "QueryBinaryIndex not supported for provided engine : %s", knnEngine.getName()) + ); } /** @@ -283,7 +364,7 @@ public static void free(final long indexPointer, final KNNEngine knnEngine, fina return; } - throw new IllegalArgumentException(String.format("Free not supported for provided engine : %s", knnEngine.getName())); + throw new IllegalArgumentException(String.format(Locale.ROOT, "Free not supported for provided engine : %s", knnEngine.getName())); } /** @@ -298,7 +379,7 @@ public static void freeSharedIndexState(long shareIndexStateAddr, KNNEngine knnE return; } throw new IllegalArgumentException( - String.format("FreeSharedIndexState not supported for provided engine : %s", knnEngine.getName()) + String.format(Locale.ROOT, "FreeSharedIndexState not supported for provided engine : %s", knnEngine.getName()) ); } @@ -319,17 +400,19 @@ public static byte[] trainIndex(Map<String, Object> indexParameters, int dimensi return FaissService.trainIndex(indexParameters, dimension, trainVectorsPointer); } - throw new IllegalArgumentException(String.format("TrainIndex not supported for provided engine : %s", knnEngine.getName())); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "TrainIndex not supported for provided engine : %s", knnEngine.getName()) + ); } /** * <p> - * The function is deprecated. Use {@link JNICommons#storeVectorData(long, float[][], long)} + * The function is deprecated. 
Use {@link JNICommons#storeVectorData(long, float[][], long, boolean)} * </p> * Transfer vectors from Java to native * * @param vectorsPointer pointer to vectors in native memory. Should be 0 to create vector as well - * @param trainingData data to be transferred + * @param trainingData data to be transferred * @return pointer to native memory location of training data */ @Deprecated(since = "2.14.0", forRemoval = true) @@ -340,15 +423,15 @@ public static long transferVectors(long vectorsPointer, float[][] trainingData) /** * Range search index for a given query vector * - * @param indexPointer pointer to index in memory - * @param queryVector vector to be used for query - * @param radius search within radius threshold - * @param methodParameters parameters to be used when loading index - * @param knnEngine engine to query index + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param radius search within radius threshold + * @param methodParameters parameters to be used when loading index + * @param knnEngine engine to query index * @param indexMaxResultWindow maximum number of results to return - * @param filteredIds list of doc ids to include in the query result - * @param filterIdsType how to filter ids: Batch or BitMap - * @param parentIds parent ids of the vectors + * @param filteredIds list of doc ids to include in the query result + * @param filterIdsType how to filter ids: Batch or BitMap + * @param parentIds parent ids of the vectors * @return KNNQueryResult array of neighbors within radius */ public static KNNQueryResult[] radiusQueryIndex( @@ -377,6 +460,6 @@ public static KNNQueryResult[] radiusQueryIndex( } return FaissService.rangeSearchIndex(indexPointer, queryVector, radius, methodParameters, indexMaxResultWindow, parentIds); } - throw new IllegalArgumentException("RadiusQueryIndex not supported for provided engine"); + throw new IllegalArgumentException(String.format(Locale.ROOT, 
"RadiusQueryIndex not supported for provided engine")); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index e87531561..277211ae6 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -104,7 +104,7 @@ public void testAddBinaryField_withKNN() throws IOException { KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, null) { @Override - public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh) { + public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge) { called[0] = true; } }; @@ -141,7 +141,7 @@ public void testAddBinaryField_withoutKNN() throws IOException { KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, state) { @Override - public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh) { + public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge) { called[0] = true; } }; @@ -159,7 +159,6 @@ public void testAddKNNBinaryField_noVectors() throws IOException { 128 ); Long initialGraphIndexRequests = KNNCounter.GRAPH_INDEX_REQUESTS.getCount(); - Long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); Long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); Long initialMergeSize = KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue(); Long initialMergeDocs = KNNGraphValue.MERGE_TOTAL_DOCS.getValue(); @@ -177,9 +176,8 @@ public void testAddKNNBinaryField_noVectors() throws IOException { SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, 
null, IOContext.DEFAULT); KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field").build(); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfo, randomVectorDocValuesProducer, true, true); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfo, randomVectorDocValuesProducer, true); assertEquals(initialGraphIndexRequests, KNNCounter.GRAPH_INDEX_REQUESTS.getCount()); - assertEquals(initialRefreshOperations, KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); assertEquals(initialMergeOperations, KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); assertEquals(initialMergeSize, KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); assertEquals(initialMergeDocs, KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); @@ -223,7 +221,6 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); - long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); // Add documents to the field @@ -232,7 +229,61 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException docsInSegment, dimension ); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true); + + // The document should be created in the correct location + String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); + assertFileInCorrectLocation(state, expectedFile); + + // The footer should be valid + assertValidFooter(state.directory, expectedFile); + + // The document should be readable by nmslib 
+ assertLoadableByEngine(null, state, expectedFile, knnEngine, spaceType, dimension); + + // The graph creation statistics should be updated + assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); + assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); + assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); + } + + public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws IOException { + // Set information about the segment and the fields + String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); + int docsInSegment = 100; + String fieldName = String.format("test_field%s", randomAlphaOfLength(4)); + + KNNEngine knnEngine = KNNEngine.NMSLIB; + SpaceType spaceType = SpaceType.COSINESIMIL; + int dimension = 16; + + SegmentInfo segmentInfo = KNNCodecTestUtil.segmentInfoBuilder() + .directory(directory) + .segmentName(segmentName) + .docsInSegment(docsInSegment) + .codec(codec) + .build(); + + FieldInfo[] fieldInfoArray = new FieldInfo[] { + KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .addAttribute(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, "512") + .addAttribute(KNNConstants.HNSW_ALGO_M, "16") + .addAttribute(KNNConstants.SPACE_TYPE, spaceType.getValue()) + .build() }; + + FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + + long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); + + // Add documents to the field + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); + TestVectorValues.RandomVectorDocValuesProducer randomVectorDocValuesProducer = new TestVectorValues.RandomVectorDocValuesProducer( + docsInSegment, + dimension + ); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], 
randomVectorDocValuesProducer, true); // The document should be created in the correct location String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); @@ -245,7 +296,6 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException assertLoadableByEngine(null, state, expectedFile, knnEngine, spaceType, dimension); // The graph creation statistics should be updated - assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); @@ -290,7 +340,6 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); - long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); // Add documents to the field KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); @@ -298,7 +347,7 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException docsInSegment, dimension ); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, false); // The document should be created in the correct location String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); @@ -312,9 +361,6 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException // The graph creation statistics should be updated 
assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); - assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); - assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); - assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); } public void testAddKNNBinaryField_whenFaissBinary_thenAdded() throws IOException { @@ -357,7 +403,6 @@ public void testAddKNNBinaryField_whenFaissBinary_thenAdded() throws IOException FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); - long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); // Add documents to the field @@ -366,7 +411,7 @@ public void testAddKNNBinaryField_whenFaissBinary_thenAdded() throws IOException docsInSegment, dimension ); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true); // The document should be created in the correct location String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); @@ -379,7 +424,6 @@ public void testAddKNNBinaryField_whenFaissBinary_thenAdded() throws IOException assertBinaryIndexLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension, dataType); // The graph creation statistics should be updated - assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); assertNotEquals(0, (long) 
KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); @@ -456,7 +500,6 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); - long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); // Add documents to the field @@ -465,7 +508,7 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio docsInSegment, dimension ); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true); // The document should be created in the correct location String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); @@ -478,7 +521,6 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio assertLoadableByEngine(HNSW_METHODPARAMETERS, state, expectedFile, knnEngine, spaceType, dimension); // The graph creation statistics should be updated - assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); @@ -550,6 +592,6 @@ public void testAddBinaryField_luceneEngine_noInvocations_addKNNBinary() throws knn80DocValuesConsumer.addBinaryField(fieldInfo, docValuesProducer); verify(delegate, times(1)).addBinaryField(fieldInfo, docValuesProducer); - verify(knn80DocValuesConsumer, never()).addKNNBinaryField(any(), any(), eq(false), eq(true)); + 
verify(knn80DocValuesConsumer, never()).addKNNBinaryField(any(), any(), eq(false)); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java index 322b714f2..feb6b9374 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java @@ -46,6 +46,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.engine.KNNEngine; @@ -99,10 +100,13 @@ public void testNativeEngineVectorFormat_whenMultipleVectorFieldIndexed_thenSucc byte[] byteVector = { 6, 14 }; addFieldToIndex( - new KnnFloatVectorField(FLOAT_VECTOR_FIELD, floatVector, createVectorField(3, VectorEncoding.FLOAT32)), + new KnnFloatVectorField(FLOAT_VECTOR_FIELD, floatVector, createVectorField(3, VectorEncoding.FLOAT32, VectorDataType.FLOAT)), + indexWriter + ); + addFieldToIndex( + new KnnByteVectorField(BYTE_VECTOR_FIELD, byteVector, createVectorField(2, VectorEncoding.BYTE, VectorDataType.BYTE)), indexWriter ); - addFieldToIndex(new KnnByteVectorField(BYTE_VECTOR_FIELD, byteVector, createVectorField(2, VectorEncoding.BYTE)), indexWriter); final IndexReader indexReader = indexWriter.getReader(); // ensuring segments are created indexWriter.flush(); @@ -187,17 +191,19 @@ private void addFieldToIndex(final Field vectorField, final RandomIndexWriter in indexWriter.addDocument(doc1); } - private FieldType createVectorField(int dimension, VectorEncoding vectorEncoding) { + private FieldType createVectorField(int dimension, VectorEncoding vectorEncoding, VectorDataType vectorDataType) { 
FieldType nativeVectorField = new FieldType(); // TODO: Replace this with the default field which will be created in mapper for Native Engines with KNNVectorsFormat nativeVectorField.setTokenized(false); nativeVectorField.setIndexOptions(IndexOptions.NONE); nativeVectorField.putAttribute(KNNVectorFieldMapper.KNN_FIELD, "true"); nativeVectorField.putAttribute(KNNConstants.KNN_METHOD, KNNConstants.METHOD_HNSW); - nativeVectorField.putAttribute(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()); + nativeVectorField.putAttribute(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()); nativeVectorField.putAttribute(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()); nativeVectorField.putAttribute(KNNConstants.HNSW_ALGO_M, "32"); nativeVectorField.putAttribute(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, "512"); + nativeVectorField.putAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); + nativeVectorField.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"HNSW16,Flat\", \"spaceType\": \"l2\"}"); nativeVectorField.setVectorAttributes( dimension, vectorEncoding, diff --git a/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java new file mode 100644 index 000000000..f4aed1049 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java @@ -0,0 +1,134 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import lombok.SneakyThrows; +import org.apache.lucene.index.DocsWithFieldSet; +import org.junit.Before; +import org.mockito.Mock; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.core.common.unit.ByteSizeValue; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.VectorDataType; +import 
org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; + +import java.util.Map; + +import static org.mockito.Mockito.when; +import static org.opensearch.knn.index.KNNSettings.KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING; + +public class OffHeapVectorTransferTests extends KNNTestCase { + + @Mock + ClusterSettings clusterSettings; + + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + KNNSettings.state().setClusterService(clusterService); + } + + @SneakyThrows + public void testFloatTransfer() { + // Given + when(clusterSettings.get(KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING)).thenReturn(new ByteSizeValue(16)); + final Map<Integer, float[]> docs = Map.of(0, new float[] { 1, 2 }, 1, new float[] { 2, 3 }, 2, new float[] { 3, 4 }); + DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); + docs.keySet().stream().sorted().forEach(docsWithFieldSet::add); + + //Transfer 1 vector + KNNFloatVectorValues knnVectorValues = (KNNFloatVectorValues) KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, docs); + knnVectorValues.nextDoc(); knnVectorValues.getVector(); + VectorTransfer vectorTransfer; + + //Transfer batch, limit == batch size + knnVectorValues = (KNNFloatVectorValues) KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, docs); + knnVectorValues.nextDoc(); knnVectorValues.getVector(); + vectorTransfer = new OffHeapFloatVectorTransfer(knnVectorValues); + testTransferBatchVectors(vectorTransfer, new int[][] { { 0, 1 }, { 2 } }, 2); + + //Transfer batch, limit < batch size + knnVectorValues = (KNNFloatVectorValues) KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, docs); + knnVectorValues.nextDoc(); knnVectorValues.getVector(); + vectorTransfer 
= new OffHeapFloatVectorTransfer(knnVectorValues, 5L); + vectorTransfer.transferBatch(); + assertNotEquals(0, vectorTransfer.getVectorAddress()); + assertArrayEquals(new int[] {0, 1, 2}, vectorTransfer.getTransferredDocsIds()); + + //Transfer batch, limit > batch size + knnVectorValues = (KNNFloatVectorValues) KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, docs); + knnVectorValues.nextDoc(); knnVectorValues.getVector(); + vectorTransfer = new OffHeapFloatVectorTransfer(knnVectorValues, 1L); + testTransferBatchVectors(vectorTransfer, new int[][] { { 0 }, { 1 }, { 2 } }, 3); + } + + @SneakyThrows + public void testByteTransfer() { + // Given + when(clusterSettings.get(KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING)).thenReturn(new ByteSizeValue(4)); + final Map<Integer, byte[]> docs = Map.of(0, new byte[] { 1, 2 }, 1, new byte[] { 2, 3 }, 2, new byte[] { 3, 4 }); + DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); + docs.keySet().stream().sorted().forEach(docsWithFieldSet::add); + + //Transfer 1 vector + KNNVectorValues<byte[]> knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.BYTE, docsWithFieldSet, docs); + knnVectorValues.nextDoc(); knnVectorValues.getVector(); + VectorTransfer vectorTransfer; + + //Transfer batch, limit == batch size + knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.BYTE, docsWithFieldSet, docs); + knnVectorValues.nextDoc(); knnVectorValues.getVector(); + vectorTransfer = new OffHeapByteQuantizedVectorTransfer<>(knnVectorValues); + testTransferBatchVectors(vectorTransfer, new int[][] { { 0, 1 }, { 2 } }, 2); + + //Transfer batch, limit < batch size + knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.BYTE, docsWithFieldSet, docs); + knnVectorValues.nextDoc(); knnVectorValues.getVector(); + vectorTransfer = new OffHeapByteQuantizedVectorTransfer<>(knnVectorValues, 5L); + vectorTransfer.transferBatch(); + assertNotEquals(0, 
vectorTransfer.getVectorAddress()); + assertArrayEquals(new int[] {0, 1, 2}, vectorTransfer.getTransferredDocsIds()); + + //Transfer batch, limit > batch size + knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.BYTE, docsWithFieldSet, docs); + knnVectorValues.nextDoc(); knnVectorValues.getVector(); + vectorTransfer = new OffHeapByteQuantizedVectorTransfer<>(knnVectorValues, 1L); + testTransferBatchVectors(vectorTransfer, new int[][] { { 0 }, { 1 }, { 2 } }, 3); + } + + // TODO: Add a unit test for binary + + @SneakyThrows + private void testTransferBatchVectors(VectorTransfer vectorTransfer, int[][] expectedDocIds, int expectedIterations) { + long vectorAddress = 0L; + try { + int iteration = 0; + while (vectorTransfer.hasNext()) { + vectorTransfer.transferBatch(); + if (iteration != 0) { + assertEquals("Vector address shouldn't be different", vectorAddress, vectorTransfer.getVectorAddress()); + } else { + assertEquals(0, vectorAddress); + vectorAddress = vectorTransfer.getVectorAddress(); + } + assertArrayEquals(expectedDocIds[iteration], vectorTransfer.getTransferredDocsIds()); + iteration++; + } + assertEquals(expectedIterations, iteration); + } finally { + vectorTransfer.close(); + assertEquals(vectorTransfer.getVectorAddress(), 0); + assertNull(vectorTransfer.getTransferredDocsIds()); + } + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferByteTests.java b/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferByteTests.java deleted file mode 100644 index 2f091a035..000000000 --- a/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferByteTests.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.transfer; - -import junit.framework.TestCase; -import lombok.SneakyThrows; -import org.apache.lucene.util.BytesRef; -import 
org.opensearch.knn.index.codec.util.SerializationMode; -import org.opensearch.knn.jni.JNICommons; - -import java.io.IOException; -import java.util.Random; - -import static org.junit.Assert.assertNotEquals; - -public class VectorTransferByteTests extends TestCase { - @SneakyThrows - public void testTransfer_whenCalled_thenAdded() { - final BytesRef bytesRef1 = getByteArrayOfVectors(20); - final BytesRef bytesRef2 = getByteArrayOfVectors(20); - VectorTransferByte vectorTransfer = new VectorTransferByte(40); - try { - vectorTransfer.init(2); - - vectorTransfer.transfer(bytesRef1); - // flush is not called - assertEquals(0, vectorTransfer.getVectorAddress()); - - vectorTransfer.transfer(bytesRef2); - // flush should be called - assertNotEquals(0, vectorTransfer.getVectorAddress()); - } finally { - if (vectorTransfer.getVectorAddress() != 0) { - JNICommons.freeVectorData(vectorTransfer.getVectorAddress()); - } - } - } - - @SneakyThrows - public void testSerializationMode_whenCalled_thenReturn() { - final BytesRef bytesRef = getByteArrayOfVectors(20); - VectorTransferByte vectorTransfer = new VectorTransferByte(1000); - - // Verify - assertEquals(SerializationMode.COLLECTIONS_OF_BYTES, vectorTransfer.getSerializationMode(bytesRef)); - } - - private BytesRef getByteArrayOfVectors(int vectorLength) throws IOException { - byte[] vector = new byte[vectorLength]; - new Random().nextBytes(vector); - return new BytesRef(vector); - } -} diff --git a/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloatTests.java b/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloatTests.java deleted file mode 100644 index 620fd7c65..000000000 --- a/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloatTests.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.transfer; - -import junit.framework.TestCase; -import 
lombok.SneakyThrows; -import org.apache.lucene.util.BytesRef; -import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; -import org.opensearch.knn.jni.JNICommons; - -import java.io.ByteArrayOutputStream; -import java.io.DataOutputStream; -import java.io.IOException; -import java.util.Random; -import java.util.stream.IntStream; - -import static org.junit.Assert.assertNotEquals; - -public class VectorTransferFloatTests extends TestCase { - @SneakyThrows - public void testTransfer_whenCalled_thenAdded() { - final BytesRef bytesRef1 = getByteArrayOfVectors(20); - final BytesRef bytesRef2 = getByteArrayOfVectors(20); - VectorTransferFloat vectorTransfer = new VectorTransferFloat(160); - try { - vectorTransfer.init(2); - - vectorTransfer.transfer(bytesRef1); - // flush is not called - assertEquals(0, vectorTransfer.getVectorAddress()); - - vectorTransfer.transfer(bytesRef2); - // flush should be called - assertNotEquals(0, vectorTransfer.getVectorAddress()); - } finally { - if (vectorTransfer.getVectorAddress() != 0) { - JNICommons.freeVectorData(vectorTransfer.getVectorAddress()); - } - } - } - - @SneakyThrows - public void testSerializationMode_whenCalled_thenReturn() { - final BytesRef bytesRef = getByteArrayOfVectors(20); - VectorTransferFloat vectorTransfer = new VectorTransferFloat(1000); - - // Verify - assertEquals(KNNVectorSerializerFactory.getSerializerModeFromBytesRef(bytesRef), vectorTransfer.getSerializationMode(bytesRef)); - } - - private BytesRef getByteArrayOfVectors(int vectorLength) throws IOException { - float[] vector = new float[vectorLength]; - IntStream.range(0, vectorLength).forEach(index -> vector[index] = new Random().nextFloat()); - - final ByteArrayOutputStream bas = new ByteArrayOutputStream(); - final DataOutputStream ds = new DataOutputStream(bas); - for (float f : vector) { - ds.writeFloat(f); - } - final byte[] vectorAsCollectionOfFloats = bas.toByteArray(); - return new BytesRef(vectorAsCollectionOfFloats); - } -} diff 
--git a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java index 47dd1dda9..dbea6375b 100644 --- a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java @@ -6,54 +6,11 @@ package org.opensearch.knn.index.codec.util; import junit.framework.TestCase; -import lombok.SneakyThrows; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.util.BytesRef; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.codec.transfer.VectorTransfer; -import java.util.Arrays; - -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize; public class KNNCodecUtilTests extends TestCase { - @SneakyThrows - public void testGetPair_whenCalled_thenReturn() { - long liveDocCount = 1l; - int[] docId = { 2 }; - long vectorAddress = 3l; - int dimension = 4; - BytesRef bytesRef = new BytesRef(); - - BinaryDocValues binaryDocValues = mock(BinaryDocValues.class); - when(binaryDocValues.cost()).thenReturn(liveDocCount); - when(binaryDocValues.nextDoc()).thenReturn(docId[0], NO_MORE_DOCS); - when(binaryDocValues.binaryValue()).thenReturn(bytesRef); - - VectorTransfer vectorTransfer = mock(VectorTransfer.class); - when(vectorTransfer.getSerializationMode(any(BytesRef.class))).thenReturn(SerializationMode.COLLECTIONS_OF_BYTES); - when(vectorTransfer.getVectorAddress()).thenReturn(vectorAddress); - when(vectorTransfer.getDimension()).thenReturn(dimension); - - // Run - KNNCodecUtil.Pair pair = KNNCodecUtil.getPair(binaryDocValues, vectorTransfer); - - // Verify - 
verify(vectorTransfer).init(liveDocCount); - verify(vectorTransfer).getSerializationMode(any(BytesRef.class)); - verify(vectorTransfer).transfer(any(BytesRef.class)); - verify(vectorTransfer).close(); - - assertTrue(Arrays.equals(docId, pair.docs)); - assertEquals(vectorAddress, pair.getVectorAddress()); - assertEquals(dimension, pair.getDimension()); - assertEquals(SerializationMode.COLLECTIONS_OF_BYTES, pair.serializationMode); - } public void testCalculateArraySize() { int numVectors = 4; diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java index dc9a97fbf..316582f6c 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java @@ -14,6 +14,7 @@ import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.TestUtils; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.VectorDataType; @@ -56,7 +57,7 @@ public void testIndexAllocation_close() throws InterruptedException { } Map<String, Object> parameters = ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()); long vectorMemoryAddress = JNICommons.storeVectorData(0, vectors, numVectors * dimension); - JNIService.createIndex(ids, vectorMemoryAddress, dimension, path, parameters, knnEngine); + TestUtils.createIndex(ids, vectorMemoryAddress, dimension, path, parameters, knnEngine); // Load index into memory long memoryAddress = JNIService.loadIndex(path, parameters, knnEngine); @@ -117,7 +118,7 @@ public void testClose_whenBinaryFiass_thenSuccess() { VectorDataType.BINARY.getValue() ); long vectorMemoryAddress = JNICommons.storeByteVectorData(0, vectors, numVectors * dataLength); - JNIService.createIndex(ids, 
vectorMemoryAddress, dimension, path, parameters, knnEngine); + TestUtils.createIndex(ids, vectorMemoryAddress, dimension, path, parameters, knnEngine); // Load index into memory long memoryAddress = JNIService.loadIndex(path, parameters, knnEngine); diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java index 51f95d29a..8a38cadb5 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java @@ -15,6 +15,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.action.search.SearchResponse; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.TestUtils; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.jni.JNICommons; @@ -32,6 +33,8 @@ import java.util.Arrays; import java.util.Map; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.any; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.doAnswer; @@ -56,7 +59,7 @@ public void testIndexLoadStrategy_load() throws IOException { } Map<String, Object> parameters = ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()); long memoryAddress = JNICommons.storeVectorData(0, vectors, numVectors * dimension); - JNIService.createIndex(ids, memoryAddress, dimension, path, parameters, knnEngine); + TestUtils.createIndex(ids, memoryAddress, dimension, path, parameters, knnEngine); // Setup mock resource manager ResourceWatcherService resourceWatcherService = mock(ResourceWatcherService.class); @@ -104,7 +107,7 @@ public void testLoad_whenFaissBinary_thenSuccess() throws IOException { VectorDataType.BINARY.getValue() ); long memoryAddress = JNICommons.storeByteVectorData(0, 
vectors, numVectors); - JNIService.createIndex(ids, memoryAddress, dimension, path, parameters, knnEngine); + TestUtils.createIndex(ids, memoryAddress, dimension, path, parameters, knnEngine); // Setup mock resource manager ResourceWatcherService resourceWatcherService = mock(ResourceWatcherService.class); diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index ae9ad7106..e8c7c9488 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -85,7 +85,7 @@ public static void setUpClass() throws IOException { public void testCreateIndex_invalid_engineNotSupported() { expectThrows( IllegalArgumentException.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( new int[] {}, 0, 0, @@ -99,21 +99,14 @@ public void testCreateIndex_invalid_engineNotSupported() { public void testCreateIndex_invalid_engineNull() { expectThrows( Exception.class, - () -> JNIService.createIndex( - new int[] {}, - 0, - 0, - "test", - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - null - ) + () -> TestUtils.createIndex(new int[] {}, 0, 0, "test", ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), null) ); } public void testCreateIndex_nmslib_invalid_noSpaceType() { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -132,7 +125,7 @@ public void testCreateIndex_nmslib_invalid_vectorDocIDMismatch() throws IOExcept Path tmpFile1 = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors1[0].length, @@ -148,7 +141,7 @@ public void testCreateIndex_nmslib_invalid_vectorDocIDMismatch() throws IOExcept Path tmpFile2 = createTempFile(); expectThrows( Exception.class, - () -> 
JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress2, vectors2[0].length, @@ -167,7 +160,7 @@ public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( null, memoryAddress, 0, @@ -179,7 +172,7 @@ public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, 0, 0, @@ -191,7 +184,7 @@ public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, 0, @@ -203,12 +196,12 @@ public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex(docIds, memoryAddress, 0, tmpFile.toAbsolutePath().toString(), null, KNNEngine.NMSLIB) + () -> TestUtils.createIndex(docIds, memoryAddress, 0, tmpFile.toAbsolutePath().toString(), null, KNNEngine.NMSLIB) ); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, 0, @@ -227,7 +220,7 @@ public void testCreateIndex_nmslib_invalid_badSpace() throws IOException { Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, @@ -253,7 +246,7 @@ public void testCreateIndex_nmslib_invalid_badParameterType() throws IOException Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, @@ -273,7 +266,7 @@ public void testCreateIndex_nmslib_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, 
testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -285,7 +278,7 @@ public void testCreateIndex_nmslib_valid() throws IOException { tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -309,7 +302,7 @@ public void testCreateIndex_faiss_invalid_noSpaceType() { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -328,7 +321,7 @@ public void testCreateIndex_faiss_invalid_vectorDocIDMismatch() throws IOExcepti Path tmpFile1 = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors1[0].length, @@ -343,7 +336,7 @@ public void testCreateIndex_faiss_invalid_vectorDocIDMismatch() throws IOExcepti Path tmpFile2 = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors2[0].length, @@ -363,7 +356,7 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( null, memoryAddress, 0, @@ -375,7 +368,7 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, 0, 0, @@ -387,7 +380,7 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -399,7 +392,7 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> 
TestUtils.createIndex( docIds, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -411,7 +404,7 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -431,7 +424,7 @@ public void testCreateIndex_faiss_invalid_invalidSpace() throws IOException { Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, @@ -451,7 +444,7 @@ public void testCreateIndex_faiss_invalid_noIndexDescription() throws IOExceptio Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, @@ -469,7 +462,7 @@ public void testCreateIndex_faiss_invalid_invalidIndexDescription() throws IOExc Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, @@ -492,7 +485,7 @@ public void testCreateIndex_faiss_sqfp16_invalidIndexDescription() { Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, @@ -516,7 +509,7 @@ public void testLoadIndex_faiss_sqfp16_valid() { String sqfp16IndexDescription = "HNSW16,SQfp16"; long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, @@ -539,7 +532,7 @@ public void testQueryIndex_faiss_sqfp16_valid() { float[][] truncatedVectors = truncateToFp16Range(testData.indexData.vectors); long memoryAddress = JNICommons.storeVectorData(0, truncatedVectors, 
(long) truncatedVectors.length * truncatedVectors[0].length); Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, memoryAddress, testData.indexData.getDimension(), @@ -627,7 +620,7 @@ public void testCreateIndex_faiss_invalid_invalidParameterType() throws IOExcept Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -653,7 +646,7 @@ public void testCreateIndex_faiss_valid() throws IOException { for (String method : methods) { for (SpaceType spaceType : spaces) { Path tmpFile1 = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -670,7 +663,7 @@ public void testCreateIndex_faiss_valid() throws IOException { public void testCreateIndex_binary_faiss_valid() { Path tmpFile1 = createTempFile(); long memoryAddr = testData.loadBinaryDataToMemoryAddress(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, memoryAddr, testData.indexData.getDimension(), @@ -726,7 +719,7 @@ public void testLoadIndex_nmslib_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -762,7 +755,7 @@ public void testLoadIndex_faiss_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -792,7 +785,7 @@ public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), 
testData.indexData.getDimension(), @@ -822,7 +815,7 @@ public void testQueryIndex_nmslib_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -855,7 +848,7 @@ public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -881,7 +874,7 @@ public void testQueryIndex_faiss_valid() throws IOException { for (String method : methods) { for (SpaceType spaceType : spaces) { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -942,7 +935,7 @@ public void testQueryIndex_faiss_parentIds() throws IOException { for (String method : methods) { for (SpaceType spaceType : spaces) { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testDataNested.indexData.docs, testData.loadDataToMemoryAddress(), testDataNested.indexData.getDimension(), @@ -985,7 +978,7 @@ public void testQueryBinaryIndex_faiss_valid() { for (String method : methods) { Path tmpFile = createTempFile(); long memoryAddr = testData.loadBinaryDataToMemoryAddress(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, memoryAddr, testData.indexData.getDimension(), @@ -1064,7 +1057,7 @@ public void testFree_nmslib_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -1088,7 +1081,7 @@ public void testFree_faiss_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( 
testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -1211,7 +1204,7 @@ private long transferVectors(int numDuplicates) { return trainPointer1; } - public void testCreateIndexFromTemplate() throws IOException { + public void createIndexFromTemplate() throws IOException { long trainPointer1 = JNIService.transferVectors(0, testData.indexData.vectors); assertNotEquals(0, trainPointer1); @@ -1412,7 +1405,7 @@ private String createFaissIVFPQIndex(int ivfNlist, int pqM, int pqCodeSize, Spac private String createFaissHNSWIndex(SpaceType spaceType) throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), diff --git a/src/testFixtures/java/org/opensearch/knn/TestUtils.java b/src/testFixtures/java/org/opensearch/knn/TestUtils.java index 6676ee154..741717116 100644 --- a/src/testFixtures/java/org/opensearch/knn/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/knn/TestUtils.java @@ -19,7 +19,9 @@ import java.io.IOException; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.util.SerializationMode; +import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.jni.JNICommons; +import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.plugin.script.KNNScoringUtil; import java.util.Collections; @@ -397,11 +399,11 @@ private void initBinaryData() { } public long loadDataToMemoryAddress() { - return JNICommons.storeVectorData(0, indexData.vectors, (long) indexData.vectors.length * indexData.vectors[0].length); + return JNICommons.storeVectorData(0, indexData.vectors, (long) indexData.vectors.length * indexData.vectors[0].length, true); } public long loadBinaryDataToMemoryAddress() { - return JNICommons.storeByteVectorData(0, indexBinaryData, (long) indexBinaryData.length * indexBinaryData[0].length); + return 
JNICommons.storeByteVectorData(0, indexBinaryData, (long) indexBinaryData.length * indexBinaryData[0].length, true); } @AllArgsConstructor @@ -414,4 +416,15 @@ public static class Pair { public float[][] vectors; } } + + public static void createIndex(int[] ids, long address, int dimension, String name, Map<String, Object> parameters, KNNEngine engine) { + if (engine != KNNEngine.FAISS) { + JNIService.createIndex(ids, address, dimension, name, parameters, engine); + } else { + // Initializing with numDocs = 0 is fine here: it just means nothing is reserved up front. + long indexAddress = JNIService.initIndexFromScratch(0, dimension, parameters, engine); + JNIService.insertToIndex(ids, address, dimension, parameters, indexAddress, engine); + JNIService.writeIndex(name, indexAddress, engine, parameters); + } + } }