From 047becbb46343d6ea86cd18fe8fe0827b6d22668 Mon Sep 17 00:00:00 2001 From: Andrew Klepchick Date: Fri, 2 Aug 2024 17:36:11 -0700 Subject: [PATCH] Iterative Vector Insertion (#1840) * Rebased with new version of k-NN Signed-off-by: Andrew Klepchick * Optimized faiss insertion Signed-off-by: Andrew Klepchick * Optimized threadCount logic Signed-off-by: Andrew Klepchick * Removed IDEA files Signed-off-by: Andrew Klepchick * Removed unnecessary cmake file Signed-off-by: Andrew Klepchick * Added comments to new functions Signed-off-by: Andrew Klepchick * Removed createIndex and fixed test cases that use it Signed-off-by: Andrew Klepchick * Removed unused code Signed-off-by: Andrew Klepchick * Explained zero initialization for vector transfer Signed-off-by: Andrew Klepchick * Added locale Signed-off-by: Andrew Klepchick * Spotless Apply Signed-off-by: Andrew Klepchick * Account for zero documents in finished batch Signed-off-by: Andrew Klepchick * Changed where we check for zero docs Signed-off-by: Andrew Klepchick * Changed tip for return Signed-off-by: Andrew Klepchick * Use unique pointers to make sure resources are released on exception Signed-off-by: Andrew Klepchick * Moved createIndex to testUtils Signed-off-by: Andrew Klepchick * Fixed memory management so that the underlying index is not deleted after initialized Signed-off-by: Andrew Klepchick * Created new KNNIndexBuilder graph to make index building more modular Signed-off-by: Andrew Klepchick * Streamlined logic in KNNIndexBuilder. Signed-off-by: Andrew Klepchick * Cleaned up unnecessary code in KNN80DocValuesConsumer Signed-off-by: Andrew Klepchick * Fixed memory management process Signed-off-by: Andrew Klepchick * Added note about index initialization in faiss_index_service Signed-off-by: Andrew Klepchick * Accounted for case where the exception happens after the indexWriter is released. Signed-off-by: Andrew Klepchick * Delete jni/src/.idea/modules.xml Signed-off-by: Andrew Klepchick * Delete jni/src/.idea/vcs.xml Signed-off-by: Andrew Klepchick * Delete jni/src/.idea/workspace.xml Signed-off-by: Andrew Klepchick * Spotless apply and free iterative index on exception Signed-off-by: Andrew Klepchick * Undid hack for checking first document metrics Signed-off-by: Andrew Klepchick * Removed print statements Signed-off-by: Andrew Klepchick * Free Vector Transfer on batch ingestion Signed-off-by: Andrew Klepchick * Undid free Signed-off-by: Andrew Klepchick * Fixed check for transfer ready Signed-off-by: Andrew Klepchick * Don't crash when zero vectors inserted? Signed-off-by: Andrew Klepchick * Reverted to old insertion process? Signed-off-by: Andrew Klepchick * Spotless apply Signed-off-by: Andrew Klepchick * Added back createOutput Signed-off-by: Andrew Klepchick * Removed prior createOutput Signed-off-by: Andrew Klepchick * Test remaking vectorTransfer Signed-off-by: Andrew Klepchick * Test restructuring of insertion Signed-off-by: Andrew Klepchick * Fixed case where vector address is immediately discarded Signed-off-by: Andrew Klepchick * Spotless apply Signed-off-by: Andrew Klepchick * Split Index Builder into multiple classes Signed-off-by: Andrew Klepchick * Fixed descriptions of functions in faiss_index_service Signed-off-by: Andrew Klepchick * Added back copyright files Signed-off-by: Andrew Klepchick * Removed unused builder names Signed-off-by: Andrew Klepchick * Modified tests to work with new insertion methods Signed-off-by: Andrew Klepchick * Track index insertions Signed-off-by: Andrew Klepchick * Tracked insertions for binary indices Signed-off-by: Andrew Klepchick * Added back insertIds Signed-off-by: Andrew Klepchick * Added check for freeVectorData to see if it works with an already deleted address Signed-off-by: Andrew Klepchick * Cleaned up logs and comments in KNNIndexBuilder Signed-off-by: Andrew Klepchick * Restructured the logic for KNNIndexBuilder Signed-off-by: Andrew Klepchick * Changed package name of KNNIndexBuilder Signed-off-by: Andrew Klepchick * Changed all package names and deleted unnecessary headers Signed-off-by: Andrew Klepchick * Fixed for loop Signed-off-by: Andrew Klepchick * Removed createIndex methods for faiss index service Signed-off-by: Andrew Klepchick * Fixed package to fit naming conventions Signed-off-by: Andrew Klepchick * Changed name of index builder Signed-off-by: Andrew Klepchick * Spotless apply Signed-off-by: Andrew Klepchick * Added comments to NativeIndexBuilder and restructured Signed-off-by: Andrew Klepchick * Added deletion for memoryAddress Signed-off-by: Andrew Klepchick * Spotless apply Signed-off-by: Andrew Klepchick * Changed naming of classes to Writer and changed package name to fit conventions Signed-off-by: Andrew Klepchick * Changed NativeIndexInfo and NativeVectorInfo to follow builder pattern Signed-off-by: Andrew Klepchick * Added feature to changelog Signed-off-by: Andrew Klepchick * Added class descriptions to each NativeIndexWriter Signed-off-by: Andrew Klepchick * Changed name to getBytesPerVector Signed-off-by: Andrew Klepchick * Added == false instead of ! for readability Signed-off-by: Andrew Klepchick * Fixed changelog Signed-off-by: Andrew Klepchick * Fixed naming in docvaluesconsumer Signed-off-by: Andrew Klepchick * SpotlessApply Signed-off-by: Andrew Klepchick * Made it so that we don't reuse testValues and removed a foot gun Signed-off-by: Andrew Klepchick * Removed another foot gun in getIndexInfo Signed-off-by: Andrew Klepchick * Fixed javadoc Signed-off-by: Andrew Klepchick * Added deletion on exception cases Signed-off-by: Andrew Klepchick * Removed unnecessary delete (NativeIndexWriter will handle deletion of vectors on exception) Signed-off-by: Andrew Klepchick * Added correct logger and getWriter method to NativeIndexWriter Signed-off-by: Andrew Klepchick * Ensured memory safety on JNI layer so that Java doesn't have to wrap everything in a try catch loop. Signed-off-by: Andrew Klepchick * Refactored NativeIndexWriter and added comments to FaissService Signed-off-by: Andrew Klepchick * Removed free in the JNIExport since index will always be freed in writeIndex. Signed-off-by: Andrew Klepchick * Changed getVectorTransfer back to accept VectorDataType Signed-off-by: Andrew Klepchick * Reverted free since not guaranteed to be IDMap. Signed-off-by: Andrew Klepchick * Added all processes in addKNNBinaryField to NativeIndexWriter.createKNNIndex Signed-off-by: Andrew Klepchick * Fixed javadoc Signed-off-by: Andrew Klepchick * Applied spotless Signed-off-by: Andrew Klepchick * Added back writeFooter Signed-off-by: Andrew Klepchick * Removed threadCount fron writeIndex Signed-off-by: Andrew Klepchick * Removed redundancies in KNN80DocValuesConsumer Signed-off-by: Andrew Klepchick * Removed serializationMode Signed-off-by: Andrew Klepchick * Fixed changelog Signed-off-by: Andrew Klepchick * Fixed changelog Signed-off-by: Andrew Klepchick * Removed double free test as we don't have to worry about that anymore Signed-off-by: Andrew Klepchick * Accounted for HNSWSQ in index service Signed-off-by: Andrew Klepchick * Removed delete in catch Signed-off-by: Andrew Klepchick * Fixed faiss tests to work with writeIndex Signed-off-by: Andrew Klepchick --------- Signed-off-by: Andrew Klepchick --- CHANGELOG.md | 1 + jni/include/faiss_index_service.h | 82 +++--- jni/include/faiss_wrapper.h | 9 +- .../org_opensearch_knn_jni_FaissService.h | 49 +++- jni/src/faiss_index_service.cpp | 190 +++++++++--- jni/src/faiss_wrapper.cpp | 72 +++-- .../org_opensearch_knn_jni_FaissService.cpp | 76 ++++- jni/tests/commons_test.cpp | 2 + jni/tests/faiss_index_service_test.cpp | 30 +- jni/tests/faiss_wrapper_test.cpp | 94 +++++- jni/tests/mocks/faiss_index_service_mock.h | 25 +- .../KNN80Codec/KNN80DocValuesConsumer.java | 240 +-------------- .../codec/nativeindex/NativeIndexWriter.java | 273 ++++++++++++++++++ .../nativeindex/NativeIndexWriterScratch.java | 124 ++++++++ .../NativeIndexWriterScratchIter.java | 72 +++++ .../NativeIndexWriterTemplate.java | 101 +++++++ .../index/codec/transfer/VectorTransfer.java | 7 + .../codec/transfer/VectorTransferByte.java | 19 +- .../codec/transfer/VectorTransferFloat.java | 11 +- .../knn/index/codec/util/KNNCodecUtil.java | 43 ++- .../org/opensearch/knn/jni/FaissService.java | 64 +++- .../org/opensearch/knn/jni/JNIService.java | 123 ++++++-- .../index/codec/util/KNNCodecUtilTests.java | 11 +- .../memory/NativeMemoryAllocationTests.java | 5 +- .../memory/NativeMemoryLoadStrategyTests.java | 7 +- .../opensearch/knn/jni/JNIServiceTests.java | 93 +++--- .../java/org/opensearch/knn/TestUtils.java | 13 + 27 files changed, 1323 insertions(+), 513 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriterScratch.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriterScratchIter.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriterTemplate.java diff --git a/CHANGELOG.md b/CHANGELOG.md index a5c641b8f..06f5fb4de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features * Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation [#1945](https://github.com/opensearch-project/k-NN/pull/1945) ### Enhancements +* Add functionality to iteratively insert vectors into a faiss index to improve the memory footprint during indexing. [#1840](https://github.com/opensearch-project/k-NN/pull/1840) ### Bug Fixes * Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874) * Add script_fields context to KNNAllowlist [#1917] (https://github.com/opensearch-project/k-NN/pull/1917) diff --git a/jni/include/faiss_index_service.h b/jni/include/faiss_index_service.h index 59f15fda9..f147a6e7e 100644 --- a/jni/include/faiss_index_service.h +++ b/jni/include/faiss_index_service.h @@ -31,36 +31,38 @@ namespace faiss_wrapper { class IndexService { public: IndexService(std::unique_ptr faissMethods); - //TODO Remove dependency on JNIUtilInterface and JNIEnv - //TODO Reduce the number of parameters - /** - * Create index + * Initialize index * * @param jniUtil jni util * @param env jni environment * @param metric space type for distance calculation * @param indexDescription index description to be used by faiss index factory * @param dim dimension of vectors + * @param numVectors number of vectors + * @param threadCount number of thread count to be used while adding data + * @param parameters parameters to be applied to faiss index + * @return memory address of the native index object + */ + virtual jlong initIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numVectors, int threadCount, std::unordered_map parameters); + /** + * Add vectors to index + * + * @param dim dimension of vectors * @param numIds number of vectors * @param threadCount number of thread count to be used while adding data * @param vectorsAddress memory address which is holding vector data - * @param ids a list of document ids for corresponding vectors + * @param idMapAddress memory address of the native index object + */ + virtual void insertToIndex(int dim, int numIds, int threadCount, int64_t vectorsAddress, std::vector &ids, jlong idMapAddress); + /** + * Write index to disk + * + * @param threadCount number of thread count to be used while adding data * @param indexPath path to write index - * @param parameters parameters to be applied to faiss index + * @param idMap memory address of the native index object */ - virtual void createIndex( - knn_jni::JNIUtilInterface * jniUtil, - JNIEnv * env, - faiss::MetricType metric, - std::string indexDescription, - int dim, - int numIds, - int threadCount, - int64_t vectorsAddress, - std::vector ids, - std::string indexPath, - std::unordered_map parameters); + virtual void writeIndex(std::string indexPath, jlong idMapAddress); virtual ~IndexService() = default; protected: std::unique_ptr faissMethods; @@ -76,7 +78,21 @@ class BinaryIndexService : public IndexService { //TODO Reduce the number of parameters BinaryIndexService(std::unique_ptr faissMethods); /** - * Create binary index + * Initialize index + * + * @param jniUtil jni util + * @param env jni environment + * @param metric space type for distance calculation + * @param indexDescription index description to be used by faiss index factory + * @param dim dimension of vectors + * @param numVectors number of vectors + * @param threadCount number of thread count to be used while adding data + * @param parameters parameters to be applied to faiss index + * @return memory address of the native index object + */ + virtual jlong initIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numVectors, int threadCount, std::unordered_map parameters) override; + /** + * Add vectors to index * * @param jniUtil jni util * @param env jni environment @@ -86,23 +102,23 @@ class BinaryIndexService : public IndexService { * @param numIds number of vectors * @param threadCount number of thread count to be used while adding data * @param vectorsAddress memory address which is holding vector data - * @param ids a list of document ids for corresponding vectors + * @param idMap a map of document id and vector id + * @param parameters parameters to be applied to faiss index + */ + virtual void insertToIndex(int dim, int numIds, int threadCount, int64_t vectorsAddress, std::vector &ids, jlong idMapAddress) override; + /** + * Write index to disk + * + * @param jniUtil jni util + * @param env jni environment + * @param metric space type for distance calculation + * @param indexDescription index description to be used by faiss index factory + * @param threadCount number of thread count to be used while adding data * @param indexPath path to write index + * @param idMap a map of document id and vector id * @param parameters parameters to be applied to faiss index */ - virtual void createIndex( - knn_jni::JNIUtilInterface * jniUtil, - JNIEnv * env, - faiss::MetricType metric, - std::string indexDescription, - int dim, - int numIds, - int threadCount, - int64_t vectorsAddress, - std::vector ids, - std::string indexPath, - std::unordered_map parameters - ) override; + virtual void writeIndex(std::string indexPath, jlong idMapAddress) override; virtual ~BinaryIndexService() = default; }; diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 5ad0dedc4..574efb6fd 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -18,10 +18,11 @@ namespace knn_jni { namespace faiss_wrapper { - // Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ. - // The index is serialized to indexPathJ. - void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, - jstring indexPathJ, jobject parametersJ, IndexService* indexService); + jlong InitIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong numDocs, jint dimJ, jobject parametersJ, IndexService *indexService); + + void InsertToIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jlong indexAddr, jint threadCount, IndexService *indexService); + + void WriteIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jstring indexPathJ, jlong indexAddr, IndexService *indexService); // Create an index with ids and vectors. Instead of creating a new index, this function creates the index // based off of the template index passed in. The index is serialized to indexPathJ. diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 025fb12e8..19e13d402 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -18,23 +18,54 @@ #ifdef __cplusplus extern "C" { #endif - /* * Class: org_opensearch_knn_jni_FaissService - * Method: createIndex + * Method: initIndex * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V */ -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex - (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject); - +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndex(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ); /* * Class: org_opensearch_knn_jni_FaissService - * Method: createBinaryIndex + * Method: initBinaryIndex * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V */ -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndex - (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject); - +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndex(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: insertToIndex + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JNIEnv * env, jclass cls, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, + jlong indexAddress, jint threadCount); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: insertToBinaryIndex + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, + jlong indexAddress, jint threadCount); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: writeIndex + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEnv * env, jclass cls, + jlong indexAddress, + jstring indexPathJ); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: writeBinaryIndex + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex(JNIEnv * env, jclass cls, + jlong indexAddress, + jstring indexPathJ); /* * Class: org_opensearch_knn_jni_FaissService * Method: createIndexFromTemplate diff --git a/jni/src/faiss_index_service.cpp b/jni/src/faiss_index_service.cpp index 8c5ba36af..cfb30cdb0 100644 --- a/jni/src/faiss_index_service.cpp +++ b/jni/src/faiss_index_service.cpp @@ -57,76 +57,179 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, IndexService::IndexService(std::unique_ptr faissMethods) : faissMethods(std::move(faissMethods)) {} -void IndexService::createIndex( +jlong IndexService::initIndex( knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, faiss::MetricType metric, std::string indexDescription, + int dim, + int numVectors, + int threadCount, + std::unordered_map parameters + ) { + // Create index using Faiss factory method + std::unique_ptr indexWriter(faissMethods->indexFactory(dim, indexDescription.c_str(), metric)); + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(threadCount != 0) { + omp_set_num_threads(threadCount); + } + + // Add extra parameters that cant be configured with the index factory + SetExtraParameters(jniUtil, env, parameters, indexWriter.get()); + + // Check that the index does not need to be trained + if(!indexWriter->is_trained) { + throw std::runtime_error("Index is not trained"); + } + + // Add vectors + std::unique_ptr idMap (faissMethods->indexIdMap(indexWriter.get())); + + /* + * NOTE: The process of memory allocation is currently only implemented for HNSW. + * This technique of checking the types of the index and subindices should be generalized into + * another function. + */ + + // Check to see if the current index is HNSW + faiss::IndexHNSWFlat * hnsw = dynamic_cast(idMap->index); + if(hnsw != NULL) { + // Check to see if the HNSW storage is IndexFlat + faiss::IndexFlat * storage = dynamic_cast(hnsw->storage); + if(storage != NULL) { + // Allocate enough memory for all of the vectors we plan on inserting + // We do this to avoid unnecessary memory allocations during insert + storage->codes.reserve(dim * numVectors * 4); + } + } + faiss::IndexHNSWSQ * hnswSq = dynamic_cast(idMap->index); + if(hnswSq != NULL) { + // Check to see if the HNSW storage is IndexFlat + faiss::IndexFlat * storage = dynamic_cast(hnswSq->storage); + if(storage != NULL) { + // Allocate enough memory for all of the vectors we plan on inserting + // We do this to avoid unnecessary memory allocations during insert + storage->codes.reserve(dim * numVectors * 2); + } + } + indexWriter.release(); + return reinterpret_cast(idMap.release()); +} + +void IndexService::insertToIndex( int dim, int numIds, int threadCount, int64_t vectorsAddress, - std::vector ids, - std::string indexPath, - std::unordered_map parameters + std::vector & ids, + jlong idMapAddress ) { - // Read vectors from memory address - auto *inputVectors = reinterpret_cast*>(vectorsAddress); + // Read vectors from memory address (unique ptr since we want to remove from memory after use) + std::vector * inputVectors = reinterpret_cast*>(vectorsAddress); // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value int numVectors = (int) (inputVectors->size() / (uint64_t) dim); if(numVectors == 0) { - throw std::runtime_error("Number of vectors cannot be 0"); + return; } if (numIds != numVectors) { throw std::runtime_error("Number of IDs does not match number of vectors"); } - std::unique_ptr indexWriter(faissMethods->indexFactory(dim, indexDescription.c_str(), metric)); - // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread if(threadCount != 0) { omp_set_num_threads(threadCount); } - // Add extra parameters that cant be configured with the index factory - SetExtraParameters(jniUtil, env, parameters, indexWriter.get()); - - // Check that the index does not need to be trained - if(!indexWriter->is_trained) { - throw std::runtime_error("Index is not trained"); - } + faiss::IndexIDMap * idMap = reinterpret_cast (idMapAddress); // Add vectors - std::unique_ptr idMap(faissMethods->indexIdMap(indexWriter.get())); idMap->add_with_ids(numVectors, inputVectors->data(), ids.data()); +} - // Write the index to disk - faissMethods->writeIndex(idMap.get(), indexPath.c_str()); +void IndexService::writeIndex( + std::string indexPath, + jlong idMapAddress + ) { + std::unique_ptr idMap (reinterpret_cast (idMapAddress)); + + try { + // Write the index to disk + faissMethods->writeIndex(idMap.get(), indexPath.c_str()); + } catch(std::exception &e) { + delete idMap->index; + throw std::runtime_error("Failed to write index to disk"); + } + // Free the memory used by the index + delete idMap->index; } BinaryIndexService::BinaryIndexService(std::unique_ptr faissMethods) : IndexService(std::move(faissMethods)) {} -void BinaryIndexService::createIndex( +jlong BinaryIndexService::initIndex( knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, faiss::MetricType metric, std::string indexDescription, int dim, - int numIds, + int numVectors, int threadCount, - int64_t vectorsAddress, - std::vector ids, - std::string indexPath, std::unordered_map parameters ) { - // Read vectors from memory address - auto *inputVectors = reinterpret_cast*>(vectorsAddress); + // Create index using Faiss factory method + std::unique_ptr indexWriter(faissMethods->indexBinaryFactory(dim, indexDescription.c_str())); - if (dim % 8 != 0) { - throw std::runtime_error("Dimensions should be multiply of 8"); + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(threadCount != 0) { + omp_set_num_threads(threadCount); } + + // Add extra parameters that cant be configured with the index factory + SetExtraParameters(jniUtil, env, parameters, indexWriter.get()); + + // Check that the index does not need to be trained + if(!indexWriter->is_trained) { + throw std::runtime_error("Index is not trained"); + } + + // Add vectors + std::unique_ptr idMap(faissMethods->indexBinaryIdMap(indexWriter.get())); + + /* + * NOTE: The process of memory allocation is currently only implemented for HNSW. + * This technique of checking the types of the index and subindices should be generalized into + * another function. + */ + + // Check to see if the current index is BinaryHNSW + faiss::IndexBinaryHNSW * hnsw = dynamic_cast(idMap->index); + + if(hnsw != NULL) { + // Check to see if the HNSW storage is IndexBinaryFlat + faiss::IndexBinaryFlat * storage = dynamic_cast(hnsw->storage); + if(storage != NULL) { + // Allocate enough memory for all of the vectors we plan on inserting + // We do this to avoid unnecessary memory allocations during insert + storage->xb.reserve(dim / 8 * numVectors); + } + } + indexWriter.release(); + return reinterpret_cast(idMap.release()); +} + +void BinaryIndexService::insertToIndex( + int dim, + int numIds, + int threadCount, + int64_t vectorsAddress, + std::vector & ids, + jlong idMapAddress + ) { + // Read vectors from memory address (unique ptr since we want to remove from memory after use) + std::vector * inputVectors = reinterpret_cast*>(vectorsAddress); + // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8)); if(numVectors == 0) { @@ -137,27 +240,34 @@ void BinaryIndexService::createIndex( throw std::runtime_error("Number of IDs does not match number of vectors"); } - std::unique_ptr indexWriter(faissMethods->indexBinaryFactory(dim, indexDescription.c_str())); - // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread if(threadCount != 0) { omp_set_num_threads(threadCount); } - // Add extra parameters that cant be configured with the index factory - SetExtraParameters(jniUtil, env, parameters, indexWriter.get()); - - // Check that the index does not need to be trained - if(!indexWriter->is_trained) { - throw std::runtime_error("Index is not trained"); - } + faiss::IndexBinaryIDMap * idMap = reinterpret_cast (idMapAddress); // Add vectors - std::unique_ptr idMap(faissMethods->indexBinaryIdMap(indexWriter.get())); idMap->add_with_ids(numVectors, inputVectors->data(), ids.data()); +} + +void BinaryIndexService::writeIndex( + std::string indexPath, + jlong idMapAddress + ) { + + std::unique_ptr idMap (reinterpret_cast (idMapAddress)); + + try { + // Write the index to disk + faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str()); + } catch(std::exception &e) { + delete idMap->index; + throw std::runtime_error("Failed to write index to disk"); + } - // Write the index to disk - faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str()); + // Free the memory used by the index + delete idMap->index; } } // namespace faiss_wrapper diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 1d4437414..0e1029ecf 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -88,24 +88,13 @@ bool isIndexIVFPQL2(faiss::Index * index); // IndexIDMap which has member that will point to underlying index that stores the data faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index); -void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, - jstring indexPathJ, jobject parametersJ, IndexService* indexService) { - if (idsJ == nullptr) { - throw std::runtime_error("IDs cannot be null"); - } - - if (vectorsAddressJ <= 0) { - throw std::runtime_error("VectorsAddress cannot be less than 0"); - } +jlong knn_jni::faiss_wrapper::InitIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong numDocs, jint dimJ, + jobject parametersJ, IndexService* indexService) { if(dimJ <= 0) { throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); } - if (indexPathJ == nullptr) { - throw std::runtime_error("Index path cannot be null"); - } - if (parametersJ == nullptr) { throw std::runtime_error("Parameters cannot be null"); } @@ -124,8 +113,8 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN // Dimension int dim = (int)dimJ; - // Number of vectors - int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); + // Number of docs + int docs = (int)numDocs; // Index description jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); @@ -138,25 +127,60 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); } + // Extra parameters + // TODO: parse the entire map and remove jni object + std::unordered_map subParametersCpp; + if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { + subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersCpp[knn_jni::PARAMETERS]); + } + // end parameters to pass + + // Create index + return indexService->initIndex(jniUtil, env, metric, indexDescriptionCpp, dim, numDocs, threadCount, subParametersCpp); +} + +void knn_jni::faiss_wrapper::InsertToIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, + jlong index_ptr, jint threadCount, IndexService* indexService) { + if (idsJ == nullptr) { + throw std::runtime_error("IDs cannot be null"); + } + + if (vectorsAddressJ <= 0) { + throw std::runtime_error("VectorsAddress cannot be less than 0"); + } + + if(dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); + } + + // Dimension + int dim = (int)dimJ; + + // Number of vectors + int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); + // Vectors address int64_t vectorsAddress = (int64_t)vectorsAddressJ; // Ids auto ids = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); - // Index path - std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + // Create index + indexService->insertToIndex(dim, numIds, threadCount, vectorsAddress, ids, index_ptr); +} - // Extra parameters - // TODO: parse the entire map and remove jni object - std::unordered_map subParametersCpp; - if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { - subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersCpp[knn_jni::PARAMETERS]); +void knn_jni::faiss_wrapper::WriteIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, + jstring indexPathJ, jlong index_ptr, IndexService* indexService) { + + if (indexPathJ == nullptr) { + throw std::runtime_error("Index path cannot be null"); } - // end parameters to pass + + // Index path + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); // Create index - indexService->createIndex(jniUtil, env, metric, indexDescriptionCpp, dim, numIds, threadCount, vectorsAddress, ids, indexPathCpp, subParametersCpp); + indexService->writeIndex(indexPathCpp, index_ptr); } void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 2394e2951..2b804a672 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -39,37 +39,85 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) { jniUtil.Uninitialize(env); } -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex(JNIEnv * env, jclass cls, jintArray idsJ, - jlong vectorsAddressJ, jint dimJ, - jstring indexPathJ, jobject parametersJ) +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndex(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ) { try { std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); - knn_jni::faiss_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, parametersJ, &indexService); + return knn_jni::faiss_wrapper::InitIndex(&jniUtil, env, numDocs, dimJ, parametersJ, &indexService); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return (jlong)0; +} - // Releasing the vectorsAddressJ memory as that is not required once we have created the index. - // This is not the ideal approach, please refer this gh issue for long term solution: - // https://github.com/opensearch-project/k-NN/issues/1600 +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndex(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods)); + return knn_jni::faiss_wrapper::InitIndex(&jniUtil, env, numDocs, dimJ, parametersJ, &binaryIndexService); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return (jlong)0; +} + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JNIEnv * env, jclass cls, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, + jlong indexAddress, jint threadCount) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::InsertToIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexAddress, threadCount, &indexService); delete reinterpret_cast*>(vectorsAddressJ); } catch (...) { + // NOTE: ADDING DELETE STATEMENT HERE CAUSES A CRASH! jniUtil.CatchCppExceptionAndThrowJava(env); } } -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ, +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, - jstring indexPathJ, jobject parametersJ) + jlong indexAddress, jint threadCount) { try { std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods)); - knn_jni::faiss_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, parametersJ, &binaryIndexService); - - // Releasing the vectorsAddressJ memory as that is not required once we have created the index. - // This is not the ideal approach, please refer this gh issue for long term solution: - // https://github.com/opensearch-project/k-NN/issues/1600 + knn_jni::faiss_wrapper::InsertToIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexAddress, threadCount, &binaryIndexService); delete reinterpret_cast*>(vectorsAddressJ); + } catch (...) { + // NOTE: ADDING DELETE STATEMENT HERE CAUSES A CRASH! + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEnv * env, jclass cls, + jlong indexAddress, + jstring indexPathJ) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::WriteIndex(&jniUtil, env, indexPathJ, indexAddress, &indexService); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex(JNIEnv * env, jclass cls, + jlong indexAddress, + jstring indexPathJ) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::WriteIndex(&jniUtil, env, indexPathJ, indexAddress, &binaryIndexService); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } diff --git a/jni/tests/commons_test.cpp b/jni/tests/commons_test.cpp index 630358919..98def8807 100644 --- a/jni/tests/commons_test.cpp +++ b/jni/tests/commons_test.cpp @@ -70,6 +70,8 @@ TEST(CommonsTests, BasicAssertions) { currentIndex++; } } + // Check that freeing vector data works + knn_jni::commons::freeVectorData(memoryAddress); } TEST(CommonTests, GetIntegerMethodParam) { diff --git a/jni/tests/faiss_index_service_test.cpp b/jni/tests/faiss_index_service_test.cpp index f876edced..1f00f6a1d 100644 --- a/jni/tests/faiss_index_service_test.cpp +++ b/jni/tests/faiss_index_service_test.cpp @@ -64,18 +64,9 @@ TEST(CreateIndexTest, BasicAssertions) { // Create the index knn_jni::faiss_wrapper::IndexService indexService(std::move(mockFaissMethods)); - indexService.createIndex( - &mockJNIUtil, - jniEnv, - metricType, - indexDescription, - dim, - numIds, - threadCount, - (int64_t) &vectors, - ids, - indexPath, - parametersMap); + long indexAddress = indexService.initIndex(&mockJNIUtil, jniEnv, metricType, indexDescription, dim, numIds, threadCount, parametersMap); + indexService.insertToIndex(dim, numIds, threadCount, (int64_t) &vectors, ids, indexAddress); + indexService.writeIndex(indexPath, indexAddress); } TEST(CreateBinaryIndexTest, BasicAssertions) { @@ -119,16 +110,7 @@ TEST(CreateBinaryIndexTest, BasicAssertions) { // Create the index knn_jni::faiss_wrapper::BinaryIndexService indexService(std::move(mockFaissMethods)); - indexService.createIndex( - &mockJNIUtil, - jniEnv, - metricType, - indexDescription, - dim, - numIds, - threadCount, - (int64_t) &vectors, - ids, - indexPath, - parametersMap); + long indexAddress = indexService.initIndex(&mockJNIUtil, jniEnv, metricType, indexDescription, dim, numIds, threadCount, parametersMap); + indexService.insertToIndex(dim, numIds, threadCount, (int64_t) &vectors, ids, indexAddress); + indexService.writeIndex(indexPath, indexAddress); } \ No newline at end of file diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 5ae443837..a1839c6ce 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -32,6 +32,70 @@ float rangeSearchRandomDataMin = -50; float rangeSearchRandomDataMax = 50; float rangeSearchRadius = 20000; +void createIndexIteratively( + knn_jni::JNIUtilInterface * JNIUtil, + JNIEnv *jniEnv, + std::vector & ids, + std::vector & vectors, + int dim, + std::string & indexPath, + std::unordered_map parametersMap, + IndexService * indexService, + int insertions = 10 + ) { + long numDocs = ids.size(); + if(numDocs % insertions != 0) { + throw std::invalid_argument("Number of documents should be divisible by number of insertions"); + } + long docsPerInsertion = numDocs / insertions; + long index_ptr = knn_jni::faiss_wrapper::InitIndex(JNIUtil, jniEnv, numDocs, dim, (jobject)¶metersMap, indexService); + for(int i = 0; i < insertions; i++) { + int start_idx = i * docsPerInsertion; + int end_idx = start_idx + docsPerInsertion; + std::vector insertIds; + std::vector insertVecs; + for(int j = start_idx; j < end_idx; j++) { + insertIds.push_back(j); + for(int k = 0; k < dim; k++) { + insertVecs.push_back(vectors[j * dim + k]); + } + } + knn_jni::faiss_wrapper::InsertToIndex(JNIUtil, jniEnv, reinterpret_cast(&insertIds), (jlong)&insertVecs, dim, index_ptr, 0, indexService); + } + knn_jni::faiss_wrapper::WriteIndex(JNIUtil, jniEnv, (jstring)&indexPath, index_ptr, indexService); +} + +void createBinaryIndexIteratively( + knn_jni::JNIUtilInterface * JNIUtil, + JNIEnv *jniEnv, + std::vector & ids, + std::vector & vectors, + int dim, + std::string & indexPath, + std::unordered_map parametersMap, + IndexService * indexService, + int insertions = 10 + ) { + long numDocs = ids.size();; + long index_ptr = knn_jni::faiss_wrapper::InitIndex(JNIUtil, jniEnv, numDocs, dim, (jobject)¶metersMap, indexService); + for(int i = 0; i < insertions; i++) { + int start_idx = numDocs * i / insertions; + int end_idx = numDocs * (i + 1) / insertions; + int docs_to_insert = end_idx - start_idx; + if(docs_to_insert == 0) continue; + std::vector insertIds; + std::vector insertVecs; + for(int j = start_idx; j < end_idx; j++) { + insertIds.push_back(j); + for(int k = 0; k < dim / 8; k++) { + insertVecs.push_back(vectors[j * (dim / 8) + k]); + } + } + knn_jni::faiss_wrapper::InsertToIndex(JNIUtil, jniEnv, reinterpret_cast(&insertIds), (jlong)&insertVecs, dim, index_ptr, 0, indexService); + } + knn_jni::faiss_wrapper::WriteIndex(JNIUtil, jniEnv, (jstring)&indexPath, index_ptr, indexService); +} + TEST(FaissCreateIndexTest, BasicAssertions) { // Define the data faiss::idx_t numIds = 200; @@ -63,13 +127,15 @@ TEST(FaissCreateIndexTest, BasicAssertions) { // Create the index std::unique_ptr faissMethods(new FaissMethods()); NiceMock mockIndexService(std::move(faissMethods)); - EXPECT_CALL(mockIndexService, createIndex(_, _, faiss::METRIC_L2, indexDescription, dim, (int)numIds, 0, (int64_t)&vectors, ids, indexPath, subParametersMap)) + int insertions = 10; + EXPECT_CALL(mockIndexService, initIndex(_, _, faiss::METRIC_L2, indexDescription, dim, (int)numIds, 0, subParametersMap)) + .Times(1); + EXPECT_CALL(mockIndexService, insertToIndex(dim, numIds / insertions, 0, _, _, _)) + .Times(insertions); + EXPECT_CALL(mockIndexService, writeIndex(indexPath, _)) .Times(1); - knn_jni::faiss_wrapper::CreateIndex( - &mockJNIUtil, jniEnv, reinterpret_cast(&ids), - (jlong) &vectors, dim , (jstring)&indexPath, - (jobject)¶metersMap, &mockIndexService); + createIndexIteratively(&mockJNIUtil, jniEnv, ids, vectors, dim, indexPath, parametersMap, &mockIndexService, insertions); } TEST(FaissCreateBinaryIndexTest, BasicAssertions) { @@ -103,14 +169,16 @@ TEST(FaissCreateBinaryIndexTest, BasicAssertions) { // Create the index std::unique_ptr faissMethods(new FaissMethods()); NiceMock mockIndexService(std::move(faissMethods)); - EXPECT_CALL(mockIndexService, createIndex(_, _, faiss::METRIC_L2, indexDescription, dim, (int)numIds, 0, (int64_t)&vectors, ids, indexPath, subParametersMap)) + int insertions = 10; + EXPECT_CALL(mockIndexService, initIndex(_, _, faiss::METRIC_L2, indexDescription, dim, (int)numIds, 0, subParametersMap)) + .Times(1); + EXPECT_CALL(mockIndexService, insertToIndex(dim, numIds / insertions, 0, _, _, _)) + .Times(insertions); + EXPECT_CALL(mockIndexService, writeIndex(indexPath, _)) .Times(1); // This method calls delete vectors at the end - knn_jni::faiss_wrapper::CreateIndex( - &mockJNIUtil, jniEnv, reinterpret_cast(&ids), - (jlong) &vectors, dim , (jstring)&indexPath, - (jobject)¶metersMap, &mockIndexService); + createBinaryIndexIteratively(&mockJNIUtil, jniEnv, ids, vectors, dim, indexPath, parametersMap, &mockIndexService, insertions); } TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) { @@ -683,10 +751,8 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) { // Create the index std::unique_ptr faissMethods(new FaissMethods()); knn_jni::faiss_wrapper::IndexService IndexService(std::move(faissMethods)); - knn_jni::faiss_wrapper::CreateIndex( - &mockJNIUtil, jniEnv, reinterpret_cast(&ids), - (jlong)&vectors, dim, (jstring)&indexPath, - (jobject)¶metersMap, &IndexService); + + createIndexIteratively(&mockJNIUtil, jniEnv, ids, vectors, dim, indexPath, parametersMap, &IndexService); // Make sure index can be loaded std::unique_ptr index(test_util::FaissLoadIndex(indexPath)); diff --git a/jni/tests/mocks/faiss_index_service_mock.h b/jni/tests/mocks/faiss_index_service_mock.h index 7af08c82e..285e34053 100644 --- a/jni/tests/mocks/faiss_index_service_mock.h +++ b/jni/tests/mocks/faiss_index_service_mock.h @@ -23,20 +23,37 @@ class MockIndexService : public IndexService { public: MockIndexService(std::unique_ptr faissMethods) : IndexService(std::move(faissMethods)) {}; MOCK_METHOD( - void, - createIndex, + long, + initIndex, ( knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, faiss::MetricType metric, std::string indexDescription, + int dim, + int numIds, + int threadCount, + StringToJObjectMap parameters + ), + (override)); + MOCK_METHOD( + void, + insertToIndex, + ( int dim, int numIds, int threadCount, int64_t vectorsAddress, - std::vector ids, + std::vector & ids, + long indexPtr + ), + (override)); + MOCK_METHOD( + void, + writeIndex, + ( std::string indexPath, - StringToJObjectMap parameters + long indexPtr ), (override)); }; diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 55ac5c597..d2117a3bc 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -7,74 +7,37 @@ import lombok.NonNull; import lombok.extern.log4j.Log4j2; -import org.apache.lucene.store.ChecksumIndexInput; import org.opensearch.common.StopWatch; -import org.opensearch.common.xcontent.XContentHelper; -import org.opensearch.core.common.bytes.BytesArray; -import org.opensearch.core.xcontent.MediaTypeRegistry; -import org.opensearch.core.xcontent.DeprecationHandler; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.util.IndexUtil; -import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.codec.transfer.VectorTransfer; -import org.opensearch.knn.index.codec.transfer.VectorTransferByte; -import org.opensearch.knn.index.codec.transfer.VectorTransferFloat; -import org.opensearch.knn.jni.JNIService; -import org.opensearch.knn.index.codec.util.KNNCodecUtil; import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelCache; -import org.opensearch.knn.plugin.stats.KNNCounter; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.codecs.DocValuesConsumer; import org.apache.lucene.codecs.DocValuesProducer; -import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.store.FSDirectory; -import org.apache.lucene.store.FilterDirectory; +import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.plugin.stats.KNNGraphValue; -import java.io.Closeable; import java.io.IOException; -import java.io.OutputStream; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.file.Files; -import java.nio.file.Paths; -import java.nio.file.StandardOpenOption; -import java.security.AccessController; -import java.security.PrivilegedAction; -import java.util.HashMap; -import java.util.Map; -import static org.apache.lucene.codecs.CodecUtil.FOOTER_MAGIC; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; -import static org.opensearch.knn.common.KNNConstants.PARAMETERS; -import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName; -import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize; -import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; /** * This class writes the KNN docvalues to the segments */ @Log4j2 -class KNN80DocValuesConsumer extends DocValuesConsumer implements Closeable { +class KNN80DocValuesConsumer extends DocValuesConsumer { private final Logger logger = LogManager.getLogger(KNN80DocValuesConsumer.class); private final DocValuesConsumer delegatee; private final SegmentWriteState state; - private static final Long CRC32_CHECKSUM_SANITY = 0xFFFFFFFF00000000L; - KNN80DocValuesConsumer(DocValuesConsumer delegatee, SegmentWriteState state) { this.delegatee = delegatee; this.state = state; @@ -113,156 +76,7 @@ private KNNEngine getKNNEngine(@NonNull FieldInfo field) { public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh) throws IOException { - // Get values to be indexed - BinaryDocValues values = valuesProducer.getBinary(field); - final KNNEngine knnEngine = getKNNEngine(field); - final String engineFileName = buildEngineFileName( - state.segmentInfo.name, - knnEngine.getVersion(), - field.name, - knnEngine.getExtension() - ); - final String indexPath = Paths.get( - ((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), - engineFileName - ).toString(); - - // Determine if we are creating an index from a model or from scratch - NativeIndexCreator indexCreator; - KNNCodecUtil.Pair pair; - Map fieldAttributes = field.attributes(); - VectorDataType vectorDataType; - - if (fieldAttributes.containsKey(MODEL_ID)) { - String modelId = fieldAttributes.get(MODEL_ID); - Model model = ModelCache.getInstance().get(modelId); - if (model.getModelBlob() == null) { - throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); - } - vectorDataType = model.getModelMetadata().getVectorDataType(); - pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType)); - indexCreator = () -> createKNNIndexFromTemplate(model, pair, knnEngine, indexPath); - } else { - // get vector data type from field attributes or provide default value - vectorDataType = VectorDataType.get( - fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()) - ); - pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType)); - indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath); - } - - // Skip index creation if no vectors or docs in segment - if (pair.getVectorAddress() == 0 || pair.docs.length == 0) { - logger.info("Skipping engine index creation as there are no vectors or docs in the segment"); - return; - } - - long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), vectorDataType); - - if (isMerge) { - KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); - KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(pair.docs.length); - KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize); - recordMergeStats(pair.docs.length, arraySize); - } - - // Increment counter for number of graph index requests - KNNCounter.GRAPH_INDEX_REQUESTS.increment(); - - if (isRefresh) { - recordRefreshStats(); - } - - // Ensure engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper - state.directory.createOutput(engineFileName, state.context).close(); - indexCreator.createIndex(); - writeFooter(indexPath, engineFileName); - } - - private void recordMergeStats(int length, long arraySize) { - KNNGraphValue.MERGE_CURRENT_OPERATIONS.decrement(); - KNNGraphValue.MERGE_CURRENT_DOCS.decrementBy(length); - KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.decrementBy(arraySize); - KNNGraphValue.MERGE_TOTAL_OPERATIONS.increment(); - KNNGraphValue.MERGE_TOTAL_DOCS.incrementBy(length); - KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.incrementBy(arraySize); - } - - private void recordRefreshStats() { - KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment(); - } - - private void createKNNIndexFromTemplate(Model model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) { - Map parameters = new HashMap<>(); - parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); - - IndexUtil.updateVectorDataTypeToParameters(parameters, model.getModelMetadata().getVectorDataType()); - - AccessController.doPrivileged((PrivilegedAction) () -> { - JNIService.createIndexFromTemplate( - pair.docs, - pair.getVectorAddress(), - pair.getDimension(), - indexPath, - model.getModelBlob(), - parameters, - knnEngine - ); - return null; - }); - } - - private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) - throws IOException { - Map parameters = new HashMap<>(); - Map fieldAttributes = fieldInfo.attributes(); - String parametersString = fieldAttributes.get(PARAMETERS); - // parametersString will be null when legacy mapper is used - if (parametersString == null) { - parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue())); - - String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION); - Map algoParams = new HashMap<>(); - if (efConstruction != null) { - algoParams.put(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, Integer.parseInt(efConstruction)); - } - - String m = fieldAttributes.get(KNNConstants.HNSW_ALGO_M); - if (m != null) { - algoParams.put(KNNConstants.METHOD_PARAMETER_M, Integer.parseInt(m)); - } - parameters.put(PARAMETERS, algoParams); - } else { - parameters.putAll( - XContentHelper.createParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, - new BytesArray(parametersString), - MediaTypeRegistry.getDefaultMediaType() - ).map() - ); - } - - // Update index description of Faiss for binary data type - if (KNNEngine.FAISS == knnEngine - && VectorDataType.BINARY.getValue() - .equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue())) - && parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null) { - parameters.put( - KNNConstants.INDEX_DESCRIPTION_PARAMETER, - FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() - ); - IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY); - } - - // Used to determine how many threads to use when indexing - parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); - - // Pass the path for the nms library to save the file - AccessController.doPrivileged((PrivilegedAction) () -> { - JNIService.createIndex(pair.docs, pair.getVectorAddress(), pair.getDimension(), indexPath, parameters, knnEngine); - return null; - }); + NativeIndexWriter.getWriter(field).createKNNIndex(field, valuesProducer, state, isMerge, isRefresh); } /** @@ -317,52 +131,4 @@ public void addNumericField(FieldInfo field, DocValuesProducer valuesProducer) t public void close() throws IOException { delegatee.close(); } - - @FunctionalInterface - private interface NativeIndexCreator { - void createIndex() throws IOException; - } - - private void writeFooter(String indexPath, String engineFileName) throws IOException { - // Opens the engine file that was created and appends a footer to it. The footer consists of - // 1. A Footer magic number (int - 4 bytes) - // 2. A checksum algorithm id (int - 4 bytes) - // 3. A checksum (long - bytes) - // The checksum is computed on all the bytes written to the file up to that point. - // Logic where footer is written in Lucene can be found here: - // https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L390-L412 - OutputStream os = Files.newOutputStream(Paths.get(indexPath), StandardOpenOption.APPEND); - ByteBuffer byteBuffer = ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN); - byteBuffer.putInt(FOOTER_MAGIC); - byteBuffer.putInt(0); - os.write(byteBuffer.array()); - os.flush(); - - ChecksumIndexInput checksumIndexInput = state.directory.openChecksumInput(engineFileName, state.context); - checksumIndexInput.seek(checksumIndexInput.length()); - long value = checksumIndexInput.getChecksum(); - checksumIndexInput.close(); - - if (isChecksumValid(value)) { - throw new IllegalStateException("Illegal CRC-32 checksum: " + value + " (resource=" + os + ")"); - } - - // Write the CRC checksum to the end of the OutputStream and close the stream - byteBuffer.putLong(0, value); - os.write(byteBuffer.array()); - os.close(); - } - - private boolean isChecksumValid(long value) { - // Check pulled from - // https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L644-L647 - return (value & CRC32_CHECKSUM_SANITY) != 0; - } - - private VectorTransfer getVectorTransfer(VectorDataType vectorDataType) { - if (VectorDataType.BINARY == vectorDataType) { - return new VectorTransferByte(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); - } - return new VectorTransferFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); - } } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java new file mode 100644 index 000000000..8cd75c4d7 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java @@ -0,0 +1,273 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; +import java.util.Map; + +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.store.FilterDirectory; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.transfer.VectorTransfer; +import org.opensearch.knn.index.codec.transfer.VectorTransferByte; +import org.opensearch.knn.index.codec.transfer.VectorTransferFloat; +import org.opensearch.knn.index.codec.util.KNNCodecUtil; +import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.indices.ModelCache; +import org.opensearch.knn.plugin.stats.KNNGraphValue; + +import lombok.Builder; +import lombok.NonNull; +import lombok.Value; +import lombok.extern.log4j.Log4j2; + +import static org.apache.lucene.codecs.CodecUtil.FOOTER_MAGIC; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName; + +/** + * Abstract class to build the KNN index and write it to disk + */ +@Log4j2 +public abstract class NativeIndexWriter { + private static final Long CRC32_CHECKSUM_SANITY = 0xFFFFFFFF00000000L; + + /** + * Class that holds info about vectors + */ + @Builder + @Value + protected static class NativeVectorInfo { + private VectorDataType vectorDataType; + private int dimension; + } + + /** + * Class that holds info about the native index + */ + @Builder + @Value + protected static class NativeIndexInfo { + private FieldInfo fieldInfo; + private KNNEngine knnEngine; + private int numDocs; + private long arraySize; + private Map parameters; + private NativeVectorInfo vectorInfo; + private String indexPath; + } + + /** + * Gets the correct writer type from fieldInfo + * + * @param fieldInfo + * @return correct NativeIndexWriter to make index specified in fieldInfo + */ + public static NativeIndexWriter getWriter(FieldInfo fieldInfo) { + final KNNEngine knnEngine = getKNNEngine(fieldInfo); + boolean fromScratch = !fieldInfo.attributes().containsKey(MODEL_ID); + boolean iterative = fromScratch && KNNEngine.FAISS == knnEngine; + if (fromScratch && iterative) { + return new NativeIndexWriterScratchIter(); + } else if (fromScratch) { + return new NativeIndexWriterScratch(); + } else { + return new NativeIndexWriterTemplate(); + } + } + + /** + * Method for creating a KNN index in the specified native library + * + * @param fieldInfo + * @param valuesProducer + * @param state + * @param isMerge + * @param isRefresh + * @throws IOException + */ + public void createKNNIndex( + FieldInfo fieldInfo, + DocValuesProducer valuesProducer, + SegmentWriteState state, + boolean isMerge, + boolean isRefresh + ) throws IOException { + BinaryDocValues values = valuesProducer.getBinary(fieldInfo); + if (KNNCodecUtil.getTotalLiveDocsCount(values) == 0) { + log.debug("No live docs for field " + fieldInfo.name); + return; + } + final KNNEngine knnEngine = getKNNEngine(fieldInfo); + final String engineFileName = buildEngineFileName( + state.segmentInfo.name, + knnEngine.getVersion(), + fieldInfo.name, + knnEngine.getExtension() + ); + final String indexPath = Paths.get( + ((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), + engineFileName + ).toString(); + + state.directory.createOutput(engineFileName, state.context).close(); + NativeIndexInfo indexInfo = getIndexInfo(fieldInfo, valuesProducer, indexPath); + if (isMerge) { + startMergeStats(indexInfo.numDocs, indexInfo.arraySize); + } + if (isRefresh) { + recordRefreshStats(); + } + createIndex(indexInfo, values); + if (isMerge) { + endMergeStats(indexInfo.numDocs, indexInfo.arraySize); + } + writeFooter(indexPath, engineFileName, state); + } + + /** + * Method that makes a native index given the parameters from indexInfo + * @param indexInfo + * @param values + * @throws IOException + */ + protected abstract void createIndex(NativeIndexInfo indexInfo, BinaryDocValues values) throws IOException; + + /** + * Method that generates extra index parameters to be passed to the native library + * @param fieldInfo + * @param knnEngine + * @return extra index parameters to be passed to the native library + * @throws IOException + */ + protected abstract Map getParameters(FieldInfo fieldInfo, KNNEngine knnEngine) throws IOException; + + /** + * Method that gets the native vector info + * @param fieldInfo + * @param valuesProducer + * @return native vector info + * @throws IOException + */ + protected abstract NativeVectorInfo getVectorInfo(FieldInfo fieldInfo, DocValuesProducer valuesProducer) throws IOException; + + protected VectorTransfer getVectorTransfer(VectorDataType vectorDataType) { + if (VectorDataType.BINARY == vectorDataType) { + return new VectorTransferByte(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); + } + return new VectorTransferFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); + } + + /** + * Method that gets the native index info from a given field + * @param fieldInfo + * @param valuesProducer + * @param indexPath + * @return native index info + * @throws IOException + */ + private NativeIndexInfo getIndexInfo(FieldInfo fieldInfo, DocValuesProducer valuesProducer, String indexPath) throws IOException { + int numDocs = (int) KNNCodecUtil.getTotalLiveDocsCount(valuesProducer.getBinary(fieldInfo)); + NativeVectorInfo vectorInfo = getVectorInfo(fieldInfo, valuesProducer); + KNNEngine knnEngine = getKNNEngine(fieldInfo); + NativeIndexInfo indexInfo = NativeIndexInfo.builder() + .fieldInfo(fieldInfo) + .knnEngine(getKNNEngine(fieldInfo)) + .numDocs((int) numDocs) + .vectorInfo(vectorInfo) + .arraySize(numDocs * getBytesPerVector(vectorInfo)) + .parameters(getParameters(fieldInfo, knnEngine)) + .indexPath(indexPath) + .build(); + return indexInfo; + } + + private long getBytesPerVector(NativeVectorInfo vectorInfo) { + if (vectorInfo.vectorDataType == VectorDataType.BINARY) { + return vectorInfo.dimension / 8; + } else { + return vectorInfo.dimension * 4; + } + } + + private static KNNEngine getKNNEngine(@NonNull FieldInfo field) { + final String modelId = field.attributes().get(MODEL_ID); + if (modelId != null) { + var model = ModelCache.getInstance().get(modelId); + return model.getModelMetadata().getKnnEngine(); + } + final String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName()); + return KNNEngine.getEngine(engineName); + } + + private void startMergeStats(int numDocs, long arraySize) { + KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); + KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(numDocs); + KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize); + KNNGraphValue.MERGE_TOTAL_OPERATIONS.increment(); + KNNGraphValue.MERGE_TOTAL_DOCS.incrementBy(numDocs); + KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.incrementBy(arraySize); + } + + private void endMergeStats(int numDocs, long arraySize) { + KNNGraphValue.MERGE_CURRENT_OPERATIONS.decrement(); + KNNGraphValue.MERGE_CURRENT_DOCS.decrementBy(numDocs); + KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.decrementBy(arraySize); + } + + private void recordRefreshStats() { + KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment(); + } + + private boolean isChecksumValid(long value) { + // Check pulled from + // https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L644-L647 + return (value & CRC32_CHECKSUM_SANITY) != 0; + } + + private void writeFooter(String indexPath, String engineFileName, SegmentWriteState state) throws IOException { + // Opens the engine file that was created and appends a footer to it. The footer consists of + // 1. A Footer magic number (int - 4 bytes) + // 2. A checksum algorithm id (int - 4 bytes) + // 3. A checksum (long - bytes) + // The checksum is computed on all the bytes written to the file up to that point. + // Logic where footer is written in Lucene can be found here: + // https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L390-L412 + OutputStream os = Files.newOutputStream(Paths.get(indexPath), StandardOpenOption.APPEND); + ByteBuffer byteBuffer = ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN); + byteBuffer.putInt(FOOTER_MAGIC); + byteBuffer.putInt(0); + os.write(byteBuffer.array()); + os.flush(); + + ChecksumIndexInput checksumIndexInput = state.directory.openChecksumInput(engineFileName, state.context); + checksumIndexInput.seek(checksumIndexInput.length()); + long value = checksumIndexInput.getChecksum(); + checksumIndexInput.close(); + + if (isChecksumValid(value)) { + throw new IllegalStateException("Illegal CRC-32 checksum: " + value + " (resource=" + os + ")"); + } + + // Write the CRC checksum to the end of the OutputStream and close the stream + byteBuffer.putLong(0, value); + os.write(byteBuffer.array()); + os.close(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriterScratch.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriterScratch.java new file mode 100644 index 000000000..3a410e801 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriterScratch.java @@ -0,0 +1,124 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.HashMap; +import java.util.Map; + +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.util.BytesRef; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.IndexUtil; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.transfer.VectorTransfer; +import org.opensearch.knn.index.codec.util.KNNCodecUtil; +import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.jni.JNIService; + +import lombok.extern.log4j.Log4j2; + +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.index.util.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; + +/** + * Class to build the KNN index from scratch and write it to disk + */ +@Log4j2 +public class NativeIndexWriterScratch extends NativeIndexWriter { + + protected NativeVectorInfo getVectorInfo(FieldInfo fieldInfo, DocValuesProducer valuesProducer) throws IOException { + // Hack to get the data metrics from the first document. We account for this in KNNCodecUtil. + BinaryDocValues testValues = valuesProducer.getBinary(fieldInfo); + testValues.nextDoc(); + BytesRef firstDoc = testValues.binaryValue(); + VectorDataType vectorDataType = VectorDataType.get( + fieldInfo.attributes().getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()) + ); + int dimension = 0; + if (vectorDataType == VectorDataType.BINARY) { + dimension = firstDoc.length * 8; + } else { + dimension = firstDoc.length / 4; + } + NativeVectorInfo vectorInfo = NativeVectorInfo.builder().vectorDataType(vectorDataType).dimension(dimension).build(); + return vectorInfo; + } + + protected Map getParameters(FieldInfo fieldInfo, KNNEngine knnEngine) throws IOException { + Map parameters = new HashMap<>(); + Map fieldAttributes = fieldInfo.attributes(); + String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); + + // parametersString will be null when legacy mapper is used + if (parametersString == null) { + parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue())); + + String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION); + Map algoParams = new HashMap<>(); + if (efConstruction != null) { + algoParams.put(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, Integer.parseInt(efConstruction)); + } + + String m = fieldAttributes.get(KNNConstants.HNSW_ALGO_M); + if (m != null) { + algoParams.put(KNNConstants.METHOD_PARAMETER_M, Integer.parseInt(m)); + } + parameters.put(PARAMETERS, algoParams); + } else { + parameters.putAll( + XContentHelper.createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(parametersString), + MediaTypeRegistry.getDefaultMediaType() + ).map() + ); + } + + // Update index description of Faiss for binary data type + if (KNNEngine.FAISS == knnEngine + && VectorDataType.BINARY.getValue() + .equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue())) + && parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null) { + parameters.put( + KNNConstants.INDEX_DESCRIPTION_PARAMETER, + FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() + ); + IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY); + } + // Used to determine how many threads to use when indexing + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + return parameters; + } + + protected void createIndex(NativeIndexInfo indexInfo, BinaryDocValues values) throws IOException { + VectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorInfo().getVectorDataType()); + KNNCodecUtil.VectorBatch batch = KNNCodecUtil.getVectorBatch(values, vectorTransfer, false); + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.createIndex( + batch.docs, + batch.getVectorAddress(), + batch.getDimension(), + indexInfo.getIndexPath(), + indexInfo.getParameters(), + indexInfo.getKnnEngine() + ); + return null; + }); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriterScratchIter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriterScratchIter.java new file mode 100644 index 000000000..c3848d7e4 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriterScratchIter.java @@ -0,0 +1,72 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Map; + +import org.apache.lucene.index.BinaryDocValues; +import org.opensearch.knn.index.codec.util.KNNCodecUtil; +import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.jni.JNIService; + +import lombok.extern.log4j.Log4j2; + +/** + * Class to build the KNN index from scratch iteratively and write it to disk + */ +@Log4j2 +public class NativeIndexWriterScratchIter extends NativeIndexWriterScratch { + + @Override + protected void createIndex(NativeIndexInfo indexInfo, BinaryDocValues values) throws IOException { + long indexAddress = initIndexFromScratch( + indexInfo.getNumDocs(), + indexInfo.getVectorInfo().getDimension(), + indexInfo.getKnnEngine(), + indexInfo.getParameters() + ); + while (true) { + KNNCodecUtil.VectorBatch batch = KNNCodecUtil.getVectorBatch( + values, + getVectorTransfer(indexInfo.getVectorInfo().getVectorDataType()), + true + ); + insertToIndex(batch, indexInfo.getKnnEngine(), indexAddress, indexInfo.getParameters()); + if (batch.finished) { + break; + } + } + writeIndex(indexAddress, indexInfo.getIndexPath(), indexInfo.getKnnEngine(), indexInfo.getParameters()); + } + + private long initIndexFromScratch(long size, int dim, KNNEngine knnEngine, Map parameters) throws IOException { + return AccessController.doPrivileged((PrivilegedAction) () -> { + return JNIService.initIndexFromScratch(size, dim, parameters, knnEngine); + }); + } + + private void insertToIndex(KNNCodecUtil.VectorBatch batch, KNNEngine knnEngine, long indexAddress, Map parameters) + throws IOException { + if (batch.docs.length == 0) { + log.debug("Index insertion called with a batch without docs."); + return; + } + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.insertToIndex(batch.docs, batch.getVectorAddress(), batch.getDimension(), parameters, indexAddress, knnEngine); + return null; + }); + } + + private void writeIndex(long indexAddress, String indexPath, KNNEngine knnEngine, Map parameters) throws IOException { + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.writeIndex(indexPath, indexAddress, knnEngine, parameters); + return null; + }); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriterTemplate.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriterTemplate.java new file mode 100644 index 000000000..f1cb84f97 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriterTemplate.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.HashMap; +import java.util.Map; + +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.IndexUtil; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.util.KNNCodecUtil; +import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.indices.Model; +import org.opensearch.knn.indices.ModelCache; +import org.opensearch.knn.jni.JNIService; + +import lombok.extern.log4j.Log4j2; + +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; + +/** + * Abstract class to build the KNN index from a template model and write it to disk + */ +@Log4j2 +public class NativeIndexWriterTemplate extends NativeIndexWriter { + + protected void createIndex(NativeIndexInfo indexInfo, BinaryDocValues values) throws IOException { + String modelId = indexInfo.getFieldInfo().attributes().get(MODEL_ID); + Model model = ModelCache.getInstance().get(modelId); + if (model.getModelBlob() == null) { + throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); + } + byte[] modelBlob = model.getModelBlob(); + IndexUtil.updateVectorDataTypeToParameters(indexInfo.getParameters(), model.getModelMetadata().getVectorDataType()); + // This is carried over from the old index creation process. Why can't we get the vector data type + // by just reading it from the field? + KNNCodecUtil.VectorBatch batch = KNNCodecUtil.getVectorBatch( + values, + getVectorTransfer(indexInfo.getVectorInfo().getVectorDataType()), + false + ); + + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.createIndexFromTemplate( + batch.docs, + batch.getVectorAddress(), + batch.getDimension(), + indexInfo.getIndexPath(), + modelBlob, + indexInfo.getParameters(), + indexInfo.getKnnEngine() + ); + return null; + }); + } + + @Override + protected Map getParameters(FieldInfo fieldInfo, KNNEngine knnEngine) throws IOException { + Map parameters = new HashMap<>(); + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + String modelId = fieldInfo.attributes().get(MODEL_ID); + Model model = ModelCache.getInstance().get(modelId); + if (model.getModelBlob() == null) { + throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); + } + IndexUtil.updateVectorDataTypeToParameters(parameters, model.getModelMetadata().getVectorDataType()); + return parameters; + } + + @Override + protected NativeVectorInfo getVectorInfo(FieldInfo fieldInfo, DocValuesProducer valuesProducer) throws IOException { + BinaryDocValues testValues = valuesProducer.getBinary(fieldInfo); + testValues.nextDoc(); + BytesRef firstDoc = testValues.binaryValue(); + String modelId = fieldInfo.attributes().get(MODEL_ID); + Model model = ModelCache.getInstance().get(modelId); + if (model.getModelBlob() == null) { + throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); + } + VectorDataType vectorDataType = model.getModelMetadata().getVectorDataType(); + int dimension = 0; + if (vectorDataType == VectorDataType.BINARY) { + dimension = firstDoc.length * 8; + } else { + dimension = firstDoc.length / 4; + } + NativeVectorInfo vectorInfo = NativeVectorInfo.builder().vectorDataType(vectorDataType).dimension(dimension).build(); + return vectorInfo; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java index c23bd4317..5c847fcc4 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java @@ -51,4 +51,11 @@ public VectorTransfer(final long vectorsStreamingMemoryLimit) { * @return serialization mode */ abstract public SerializationMode getSerializationMode(final BytesRef bytesRef); + + /** + * Get number of documents not transferred + * + * @return number of documents not transferred + */ + abstract public int numPendingDocs(); } diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java index e81ac35fc..cf4066828 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java @@ -60,9 +60,26 @@ public SerializationMode getSerializationMode(final BytesRef bytesRef) { return SerializationMode.COLLECTIONS_OF_BYTES; } + @Override + public int numPendingDocs() { + return vectorList.size(); + } + private void transfer() { int lengthOfVector = dimension / 8; - vectorAddress = JNICommons.storeByteVectorData(vectorAddress, vectorList.toArray(new byte[][] {}), totalLiveDocs * lengthOfVector); + if (totalLiveDocs != 0) { + vectorAddress = JNICommons.storeByteVectorData( + vectorAddress, + vectorList.toArray(new byte[][] {}), + totalLiveDocs * lengthOfVector + ); + } else { + vectorAddress = JNICommons.storeByteVectorData( + vectorAddress, + vectorList.toArray(new byte[][] {}), + vectorList.size() * lengthOfVector + ); + } vectorList.clear(); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java index a9c792398..b4ce95bb1 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java @@ -64,8 +64,17 @@ public SerializationMode getSerializationMode(final BytesRef bytesRef) { return KNNVectorSerializerFactory.getSerializerModeFromBytesRef(bytesRef); } + @Override + public int numPendingDocs() { + return vectorList.size(); + } + private void transfer() { - vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension); + if (totalLiveDocs != 0) { + vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension); + } else { + vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), vectorList.size() * dimension); + } vectorList.clear(); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index ea14fe883..e30154c2f 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -24,7 +24,7 @@ public class KNNCodecUtil { public static final int FLOAT_BYTE_SIZE = 4; @AllArgsConstructor - public static final class Pair { + public static final class VectorBatch { public int[] docs; @Getter @Setter @@ -32,33 +32,48 @@ public static final class Pair { @Getter @Setter private int dimension; - public SerializationMode serializationMode; + public boolean finished; } /** * Extract docIds and vectors from binary doc values. - * - * @param values Binary doc values - * @param vectorTransfer Utility to make transfer - * @return KNNCodecUtil.Pair representing doc ids and corresponding vectors - * @throws IOException thrown when unable to get binary of vectors - */ - public static KNNCodecUtil.Pair getPair(final BinaryDocValues values, final VectorTransfer vectorTransfer) throws IOException { + * + * @param values Binary doc values + * @param vectorTransfer Utility to make transfer + * @return KNNCodecUtil.Pair representing doc ids and corresponding vectors + * @throws IOException thrown when unable to get binary of vectors + */ + public static KNNCodecUtil.VectorBatch getVectorBatch( + final BinaryDocValues values, + final VectorTransfer vectorTransfer, + boolean iterative + ) throws IOException { List docIdList = new ArrayList<>(); - SerializationMode serializationMode = SerializationMode.COLLECTION_OF_FLOATS; - vectorTransfer.init(getTotalLiveDocsCount(values)); + if (iterative) { + // Initializing with a value of zero means to only allocate as much memory on JNI as + // we have inserted for vectors in java side + vectorTransfer.init(0); + } else { + vectorTransfer.init(getTotalLiveDocsCount(values)); + } for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) { BytesRef bytesref = values.binaryValue(); - serializationMode = vectorTransfer.getSerializationMode(bytesref); vectorTransfer.transfer(bytesref); docIdList.add(doc); + // Semi-hacky way to check if the streaming limit has been reached + if (iterative && vectorTransfer.numPendingDocs() == 0) { + break; + } } vectorTransfer.close(); - return new KNNCodecUtil.Pair( + + boolean finished = values.docID() == DocIdSetIterator.NO_MORE_DOCS; + + return new KNNCodecUtil.VectorBatch( docIdList.stream().mapToInt(Integer::intValue).toArray(), vectorTransfer.getVectorAddress(), vectorTransfer.getDimension(), - serializationMode + finished ); } diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 4f57b616a..a402be1f3 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -50,32 +50,70 @@ class FaissService { } /** - * Create an index for the native library The memory occupied by the vectorsAddress will be freed up during the + * Initialize an index for the native library. Takes in numDocs to + * allocate the correct amount of memory. + * + * @param numDocs number of documents to be added + * @param dim dimension of the vector to be indexed + * @param parameters parameters to build index + */ + public static native long initIndex(long numDocs, int dim, Map parameters); + + /** + * Initialize an index for the native library. Takes in numDocs to + * allocate the correct amount of memory. + * + * @param numDocs number of documents to be added + * @param dim dimension of the vector to be indexed + * @param parameters parameters to build index + */ + public static native long initBinaryIndex(long numDocs, int dim, Map parameters); + + /** + * Inserts to a faiss index. The memory occupied by the vectorsAddress will be freed up during the * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer - * created the memory address and that should only free up the memory. We are tracking the proper fix for this on this - * issue + * created the memory address and that should only free up the memory. * - * @param ids array of ids mapping to the data passed in + * @param ids ids of documents * @param vectorsAddress address of native memory where vectors are stored * @param dim dimension of the vector to be indexed - * @param indexPath path to save index file to - * @param parameters parameters to build index + * @param indexAddress address of native memory where index is stored + * @param threadCount number of threads to use for insertion */ - public static native void createIndex(int[] ids, long vectorsAddress, int dim, String indexPath, Map parameters); + public static native void insertToIndex(int[] ids, long vectorsAddress, int dim, long indexAddress, int threadCount); /** - * Create a binary index for the native library The memory occupied by the vectorsAddress will be freed up during the + * Inserts to a faiss index. The memory occupied by the vectorsAddress will be freed up during the * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer - * created the memory address and that should only free up the memory. We are tracking the proper fix for this on this - * issue + * created the memory address and that should only free up the memory. * - * @param ids array of ids mapping to the data passed in + * @param ids ids of documents * @param vectorsAddress address of native memory where vectors are stored * @param dim dimension of the vector to be indexed + * @param indexAddress address of native memory where index is stored + * @param threadCount number of threads to use for insertion + */ + public static native void insertToBinaryIndex(int[] ids, long vectorsAddress, int dim, long indexAddress, int threadCount); + + /** + * Writes a faiss index. + * + * NOTE: This will always free the index. Do not call free after this. + * + * @param indexAddress address of native memory where index is stored + * @param indexPath path to save index file to + */ + public static native void writeIndex(long indexAddress, String indexPath); + + /** + * Writes a faiss index. + * + * NOTE: This will always free the index. Do not call free after this. + * + * @param indexAddress address of native memory where index is stored * @param indexPath path to save index file to - * @param parameters parameters to build index */ - public static native void createBinaryIndex(int[] ids, long vectorsAddress, int dim, String indexPath, Map parameters); + public static native void writeBinaryIndex(long indexAddress, String indexPath); /** * Create an index for the native library with a provided template index diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index de696b5ce..d428d1ee4 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -13,16 +13,98 @@ import org.apache.commons.lang.ArrayUtils; import org.opensearch.common.Nullable; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.engine.KNNEngine; +import java.util.Locale; import java.util.Map; /** * Service to distribute requests to the proper engine jni service */ public class JNIService { + /** + * Initialize an index for the native library. Takes in numDocs to + * allocate the correct amount of memory. + * + * @param numDocs number of documents to be added + * @param dim dimension of the vector to be indexed + * @param parameters parameters to build index + * @param knnEngine knn engine + * @return address of the index in memory + */ + public static long initIndexFromScratch(long numDocs, int dim, Map parameters, KNNEngine knnEngine) { + if (KNNEngine.FAISS == knnEngine) { + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { + return FaissService.initBinaryIndex(numDocs, dim, parameters); + } else { + return FaissService.initIndex(numDocs, dim, parameters); + } + } + + throw new IllegalArgumentException( + String.format(Locale.ROOT, "initIndexFromScratch not supported for provided engine : %s", knnEngine.getName()) + ); + } + + /** + * Inserts to a faiss index. + * + * @param docs ids of documents + * @param vectorsAddress address of native memory where vectors are stored + * @param dimension dimension of the vector to be indexed + * @param parameters parameters to build index + * @param indexAddress address of native memory where index is stored + * @param knnEngine knn engine + */ + public static void insertToIndex( + int[] docs, + long vectorsAddress, + int dimension, + Map parameters, + long indexAddress, + KNNEngine knnEngine + ) { + int threadCount = (int) parameters.getOrDefault(KNNConstants.INDEX_THREAD_QTY, 0); + if (KNNEngine.FAISS == knnEngine) { + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { + FaissService.insertToBinaryIndex(docs, vectorsAddress, dimension, indexAddress, threadCount); + } else { + FaissService.insertToIndex(docs, vectorsAddress, dimension, indexAddress, threadCount); + } + return; + } + + throw new IllegalArgumentException( + String.format(Locale.ROOT, "insertToIndex not supported for provided engine : %s", knnEngine.getName()) + ); + } + + /** + * Writes a faiss index to disk. + * + * @param indexPath path to save index to + * @param indexAddress address of native memory where index is stored + * @param knnEngine knn engine + * @param parameters parameters to build index + */ + public static void writeIndex(String indexPath, long indexAddress, KNNEngine knnEngine, Map parameters) { + if (KNNEngine.FAISS == knnEngine) { + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { + FaissService.writeBinaryIndex(indexAddress, indexPath); + } else { + FaissService.writeIndex(indexAddress, indexPath); + } + return; + } + + throw new IllegalArgumentException( + String.format(Locale.ROOT, "writeIndex not supported for provided engine : %s", knnEngine.getName()) + ); + } + /** * Create an index for the native library. The memory occupied by the vectorsAddress will be freed up during the * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer @@ -50,16 +132,9 @@ public static void createIndex( return; } - if (KNNEngine.FAISS == knnEngine) { - if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { - FaissService.createBinaryIndex(ids, vectorsAddress, dim, indexPath, parameters); - } else { - FaissService.createIndex(ids, vectorsAddress, dim, indexPath, parameters); - } - return; - } - - throw new IllegalArgumentException(String.format("CreateIndex not supported for provided engine : %s", knnEngine.getName())); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "CreateIndex not supported for provided engine : %s", knnEngine.getName()) + ); } /** @@ -93,7 +168,7 @@ public static void createIndexFromTemplate( } throw new IllegalArgumentException( - String.format("CreateIndexFromTemplate not supported for provided engine : %s", knnEngine.getName()) + String.format(Locale.ROOT, "CreateIndexFromTemplate not supported for provided engine : %s", knnEngine.getName()) ); } @@ -118,7 +193,9 @@ public static long loadIndex(String indexPath, Map parameters, K } } - throw new IllegalArgumentException(String.format("LoadIndex not supported for provided engine : %s", knnEngine.getName())); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "LoadIndex not supported for provided engine : %s", knnEngine.getName()) + ); } /** @@ -150,7 +227,7 @@ public static long initSharedIndexState(long indexAddr, KNNEngine knnEngine) { return FaissService.initSharedIndexState(indexAddr); } throw new IllegalArgumentException( - String.format("InitSharedIndexState not supported for provided engine : %s", knnEngine.getName()) + String.format(Locale.ROOT, "InitSharedIndexState not supported for provided engine : %s", knnEngine.getName()) ); } @@ -168,7 +245,7 @@ public static void setSharedIndexState(long indexAddr, long shareIndexStateAddr, } throw new IllegalArgumentException( - String.format("SetSharedIndexState not supported for provided engine : %s", knnEngine.getName()) + String.format(Locale.ROOT, "SetSharedIndexState not supported for provided engine : %s", knnEngine.getName()) ); } @@ -216,7 +293,9 @@ public static KNNQueryResult[] queryIndex( } return FaissService.queryIndex(indexPointer, queryVector, k, methodParameters, parentIds); } - throw new IllegalArgumentException(String.format("QueryIndex not supported for provided engine : %s", knnEngine.getName())); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "QueryIndex not supported for provided engine : %s", knnEngine.getName()) + ); } /** @@ -252,7 +331,9 @@ public static KNNQueryResult[] queryBinaryIndex( parentIds ); } - throw new IllegalArgumentException(String.format("QueryBinaryIndex not supported for provided engine : %s", knnEngine.getName())); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "QueryBinaryIndex not supported for provided engine : %s", knnEngine.getName()) + ); } /** @@ -283,7 +364,7 @@ public static void free(final long indexPointer, final KNNEngine knnEngine, fina return; } - throw new IllegalArgumentException(String.format("Free not supported for provided engine : %s", knnEngine.getName())); + throw new IllegalArgumentException(String.format(Locale.ROOT, "Free not supported for provided engine : %s", knnEngine.getName())); } /** @@ -298,7 +379,7 @@ public static void freeSharedIndexState(long shareIndexStateAddr, KNNEngine knnE return; } throw new IllegalArgumentException( - String.format("FreeSharedIndexState not supported for provided engine : %s", knnEngine.getName()) + String.format(Locale.ROOT, "FreeSharedIndexState not supported for provided engine : %s", knnEngine.getName()) ); } @@ -319,7 +400,9 @@ public static byte[] trainIndex(Map indexParameters, int dimensi return FaissService.trainIndex(indexParameters, dimension, trainVectorsPointer); } - throw new IllegalArgumentException(String.format("TrainIndex not supported for provided engine : %s", knnEngine.getName())); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "TrainIndex not supported for provided engine : %s", knnEngine.getName()) + ); } /** @@ -377,6 +460,6 @@ public static KNNQueryResult[] radiusQueryIndex( } return FaissService.rangeSearchIndex(indexPointer, queryVector, radius, methodParameters, indexMaxResultWindow, parentIds); } - throw new IllegalArgumentException("RadiusQueryIndex not supported for provided engine"); + throw new IllegalArgumentException(String.format(Locale.ROOT, "RadiusQueryIndex not supported for provided engine")); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java index 47dd1dda9..5a68d96d4 100644 --- a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java @@ -36,23 +36,20 @@ public void testGetPair_whenCalled_thenReturn() { when(binaryDocValues.binaryValue()).thenReturn(bytesRef); VectorTransfer vectorTransfer = mock(VectorTransfer.class); - when(vectorTransfer.getSerializationMode(any(BytesRef.class))).thenReturn(SerializationMode.COLLECTIONS_OF_BYTES); when(vectorTransfer.getVectorAddress()).thenReturn(vectorAddress); when(vectorTransfer.getDimension()).thenReturn(dimension); // Run - KNNCodecUtil.Pair pair = KNNCodecUtil.getPair(binaryDocValues, vectorTransfer); + KNNCodecUtil.VectorBatch batch = KNNCodecUtil.getVectorBatch(binaryDocValues, vectorTransfer, false); // Verify verify(vectorTransfer).init(liveDocCount); - verify(vectorTransfer).getSerializationMode(any(BytesRef.class)); verify(vectorTransfer).transfer(any(BytesRef.class)); verify(vectorTransfer).close(); - assertTrue(Arrays.equals(docId, pair.docs)); - assertEquals(vectorAddress, pair.getVectorAddress()); - assertEquals(dimension, pair.getDimension()); - assertEquals(SerializationMode.COLLECTIONS_OF_BYTES, pair.serializationMode); + assertTrue(Arrays.equals(docId, batch.docs)); + assertEquals(vectorAddress, batch.getVectorAddress()); + assertEquals(dimension, batch.getDimension()); } public void testCalculateArraySize() { diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java index dc9a97fbf..316582f6c 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java @@ -14,6 +14,7 @@ import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.TestUtils; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.VectorDataType; @@ -56,7 +57,7 @@ public void testIndexAllocation_close() throws InterruptedException { } Map parameters = ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()); long vectorMemoryAddress = JNICommons.storeVectorData(0, vectors, numVectors * dimension); - JNIService.createIndex(ids, vectorMemoryAddress, dimension, path, parameters, knnEngine); + TestUtils.createIndex(ids, vectorMemoryAddress, dimension, path, parameters, knnEngine); // Load index into memory long memoryAddress = JNIService.loadIndex(path, parameters, knnEngine); @@ -117,7 +118,7 @@ public void testClose_whenBinaryFiass_thenSuccess() { VectorDataType.BINARY.getValue() ); long vectorMemoryAddress = JNICommons.storeByteVectorData(0, vectors, numVectors * dataLength); - JNIService.createIndex(ids, vectorMemoryAddress, dimension, path, parameters, knnEngine); + TestUtils.createIndex(ids, vectorMemoryAddress, dimension, path, parameters, knnEngine); // Load index into memory long memoryAddress = JNIService.loadIndex(path, parameters, knnEngine); diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java index 51f95d29a..8a38cadb5 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java @@ -15,6 +15,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.action.search.SearchResponse; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.TestUtils; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.jni.JNICommons; @@ -32,6 +33,8 @@ import java.util.Arrays; import java.util.Map; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.any; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.doAnswer; @@ -56,7 +59,7 @@ public void testIndexLoadStrategy_load() throws IOException { } Map parameters = ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()); long memoryAddress = JNICommons.storeVectorData(0, vectors, numVectors * dimension); - JNIService.createIndex(ids, memoryAddress, dimension, path, parameters, knnEngine); + TestUtils.createIndex(ids, memoryAddress, dimension, path, parameters, knnEngine); // Setup mock resource manager ResourceWatcherService resourceWatcherService = mock(ResourceWatcherService.class); @@ -104,7 +107,7 @@ public void testLoad_whenFaissBinary_thenSuccess() throws IOException { VectorDataType.BINARY.getValue() ); long memoryAddress = JNICommons.storeByteVectorData(0, vectors, numVectors); - JNIService.createIndex(ids, memoryAddress, dimension, path, parameters, knnEngine); + TestUtils.createIndex(ids, memoryAddress, dimension, path, parameters, knnEngine); // Setup mock resource manager ResourceWatcherService resourceWatcherService = mock(ResourceWatcherService.class); diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index ae9ad7106..e8c7c9488 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -85,7 +85,7 @@ public static void setUpClass() throws IOException { public void testCreateIndex_invalid_engineNotSupported() { expectThrows( IllegalArgumentException.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( new int[] {}, 0, 0, @@ -99,21 +99,14 @@ public void testCreateIndex_invalid_engineNotSupported() { public void testCreateIndex_invalid_engineNull() { expectThrows( Exception.class, - () -> JNIService.createIndex( - new int[] {}, - 0, - 0, - "test", - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - null - ) + () -> TestUtils.createIndex(new int[] {}, 0, 0, "test", ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), null) ); } public void testCreateIndex_nmslib_invalid_noSpaceType() { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -132,7 +125,7 @@ public void testCreateIndex_nmslib_invalid_vectorDocIDMismatch() throws IOExcept Path tmpFile1 = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors1[0].length, @@ -148,7 +141,7 @@ public void testCreateIndex_nmslib_invalid_vectorDocIDMismatch() throws IOExcept Path tmpFile2 = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress2, vectors2[0].length, @@ -167,7 +160,7 @@ public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( null, memoryAddress, 0, @@ -179,7 +172,7 @@ public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, 0, 0, @@ -191,7 +184,7 @@ public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, 0, @@ -203,12 +196,12 @@ public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex(docIds, memoryAddress, 0, tmpFile.toAbsolutePath().toString(), null, KNNEngine.NMSLIB) + () -> TestUtils.createIndex(docIds, memoryAddress, 0, tmpFile.toAbsolutePath().toString(), null, KNNEngine.NMSLIB) ); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, 0, @@ -227,7 +220,7 @@ public void testCreateIndex_nmslib_invalid_badSpace() throws IOException { Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, @@ -253,7 +246,7 @@ public void testCreateIndex_nmslib_invalid_badParameterType() throws IOException Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, @@ -273,7 +266,7 @@ public void testCreateIndex_nmslib_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -285,7 +278,7 @@ public void testCreateIndex_nmslib_valid() throws IOException { tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -309,7 +302,7 @@ public void testCreateIndex_faiss_invalid_noSpaceType() { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -328,7 +321,7 @@ public void testCreateIndex_faiss_invalid_vectorDocIDMismatch() throws IOExcepti Path tmpFile1 = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors1[0].length, @@ -343,7 +336,7 @@ public void testCreateIndex_faiss_invalid_vectorDocIDMismatch() throws IOExcepti Path tmpFile2 = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors2[0].length, @@ -363,7 +356,7 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( null, memoryAddress, 0, @@ -375,7 +368,7 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, 0, 0, @@ -387,7 +380,7 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -399,7 +392,7 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -411,7 +404,7 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -431,7 +424,7 @@ public void testCreateIndex_faiss_invalid_invalidSpace() throws IOException { Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, @@ -451,7 +444,7 @@ public void testCreateIndex_faiss_invalid_noIndexDescription() throws IOExceptio Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, @@ -469,7 +462,7 @@ public void testCreateIndex_faiss_invalid_invalidIndexDescription() throws IOExc Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, @@ -492,7 +485,7 @@ public void testCreateIndex_faiss_sqfp16_invalidIndexDescription() { Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, @@ -516,7 +509,7 @@ public void testLoadIndex_faiss_sqfp16_valid() { String sqfp16IndexDescription = "HNSW16,SQfp16"; long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, @@ -539,7 +532,7 @@ public void testQueryIndex_faiss_sqfp16_valid() { float[][] truncatedVectors = truncateToFp16Range(testData.indexData.vectors); long memoryAddress = JNICommons.storeVectorData(0, truncatedVectors, (long) truncatedVectors.length * truncatedVectors[0].length); Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, memoryAddress, testData.indexData.getDimension(), @@ -627,7 +620,7 @@ public void testCreateIndex_faiss_invalid_invalidParameterType() throws IOExcept Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -653,7 +646,7 @@ public void testCreateIndex_faiss_valid() throws IOException { for (String method : methods) { for (SpaceType spaceType : spaces) { Path tmpFile1 = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -670,7 +663,7 @@ public void testCreateIndex_faiss_valid() throws IOException { public void testCreateIndex_binary_faiss_valid() { Path tmpFile1 = createTempFile(); long memoryAddr = testData.loadBinaryDataToMemoryAddress(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, memoryAddr, testData.indexData.getDimension(), @@ -726,7 +719,7 @@ public void testLoadIndex_nmslib_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -762,7 +755,7 @@ public void testLoadIndex_faiss_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -792,7 +785,7 @@ public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -822,7 +815,7 @@ public void testQueryIndex_nmslib_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -855,7 +848,7 @@ public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -881,7 +874,7 @@ public void testQueryIndex_faiss_valid() throws IOException { for (String method : methods) { for (SpaceType spaceType : spaces) { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -942,7 +935,7 @@ public void testQueryIndex_faiss_parentIds() throws IOException { for (String method : methods) { for (SpaceType spaceType : spaces) { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testDataNested.indexData.docs, testData.loadDataToMemoryAddress(), testDataNested.indexData.getDimension(), @@ -985,7 +978,7 @@ public void testQueryBinaryIndex_faiss_valid() { for (String method : methods) { Path tmpFile = createTempFile(); long memoryAddr = testData.loadBinaryDataToMemoryAddress(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, memoryAddr, testData.indexData.getDimension(), @@ -1064,7 +1057,7 @@ public void testFree_nmslib_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -1088,7 +1081,7 @@ public void testFree_faiss_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -1211,7 +1204,7 @@ private long transferVectors(int numDuplicates) { return trainPointer1; } - public void testCreateIndexFromTemplate() throws IOException { + public void createIndexFromTemplate() throws IOException { long trainPointer1 = JNIService.transferVectors(0, testData.indexData.vectors); assertNotEquals(0, trainPointer1); @@ -1412,7 +1405,7 @@ private String createFaissIVFPQIndex(int ivfNlist, int pqM, int pqCodeSize, Spac private String createFaissHNSWIndex(SpaceType spaceType) throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), diff --git a/src/testFixtures/java/org/opensearch/knn/TestUtils.java b/src/testFixtures/java/org/opensearch/knn/TestUtils.java index 6676ee154..e2b831e6e 100644 --- a/src/testFixtures/java/org/opensearch/knn/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/knn/TestUtils.java @@ -19,7 +19,9 @@ import java.io.IOException; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.util.SerializationMode; +import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.jni.JNICommons; +import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.plugin.script.KNNScoringUtil; import java.util.Collections; @@ -414,4 +416,15 @@ public static class Pair { public float[][] vectors; } } + + public static void createIndex(int[] ids, long address, int dimension, String name, Map parameters, KNNEngine engine) { + if (engine != KNNEngine.FAISS) { + JNIService.createIndex(ids, address, dimension, name, parameters, engine); + } else { + // We can initialize numDocs as 0, this will just not reserve anything. + long indexAddress = JNIService.initIndexFromScratch(0, dimension, parameters, engine); + JNIService.insertToIndex(ids, address, dimension, parameters, indexAddress, engine); + JNIService.writeIndex(name, indexAddress, engine, parameters); + } + } }