From 7cf45c8cfd4219912d67becb731f8a7af61410c0 Mon Sep 17 00:00:00 2001 From: Doo Yong Kim <0ctopus13prime@gmail.com> Date: Tue, 15 Oct 2024 09:41:14 -0700 Subject: [PATCH] Introduce a loading layer in NMSLIB. (#2185) * Introduce a loading layer in NMSLIB. Signed-off-by: Dooyong Kim * Added NMSLIB istream implementation. Signed-off-by: Dooyong Kim * Fix integer overflow issue when passing read size for loading NMSLIB vector index. Signed-off-by: Dooyong Kim * Added unit test for NMSLIB loading layer. Signed-off-by: Dooyong Kim * Made a patch in NMSLIB to avoid frequently calling JNI for better loading index performance. Signed-off-by: Dooyong Kim * Compliance constexpr function in C++11 having nullstatement. Signed-off-by: Dooyong Kim --------- Signed-off-by: Dooyong Kim Co-authored-by: Dooyong Kim --- jni/CMakeLists.txt | 3 +- jni/cmake/init-nmslib.cmake | 3 +- jni/include/faiss_stream_support.h | 81 +-- jni/include/jni_util.h | 9 +- jni/include/native_engines_stream_support.h | 125 ++++ jni/include/nmslib_stream_support.h | 51 ++ jni/include/nmslib_wrapper.h | 8 + .../org_opensearch_knn_jni_NmslibService.h | 8 + ...pis-for-vector-index-loading-in-Hnsw.patch | 221 ++++++++ ...is-using-stream-to-load-save-in-Hnsw.patch | 93 --- jni/src/jni_util.cpp | 10 +- jni/src/nmslib_wrapper.cpp | 533 ++++++++++-------- .../org_opensearch_knn_jni_NmslibService.cpp | 117 ++-- jni/tests/faiss_stream_support_test.cpp | 98 ++-- jni/tests/native_stream_support_util.h | 102 ++++ jni/tests/nmslib_stream_support_test.cpp | 120 ++++ jni/tests/test_util.h | 3 +- .../index/memory/NativeMemoryAllocation.java | 4 +- .../memory/NativeMemoryLoadStrategy.java | 27 +- .../knn/index/store/IndexInputWithBuffer.java | 10 +- .../org/opensearch/knn/jni/JNIService.java | 2 + .../org/opensearch/knn/jni/NmslibService.java | 10 + .../knn/index/KNNCircuitBreakerIT.java | 2 +- .../opensearch/knn/jni/JNIServiceTests.java | 25 + 24 files changed, 1103 insertions(+), 562 deletions(-) create mode 100644 jni/include/native_engines_stream_support.h create mode 100644 jni/include/nmslib_stream_support.h create mode 100644 jni/patches/nmslib/0003-Added-streaming-apis-for-vector-index-loading-in-Hnsw.patch delete mode 100644 jni/patches/nmslib/0003-Adding-two-apis-using-stream-to-load-save-in-Hnsw.patch create mode 100644 jni/tests/native_stream_support_util.h create mode 100644 jni/tests/nmslib_stream_support_test.cpp diff --git a/jni/CMakeLists.txt b/jni/CMakeLists.txt index 4caa907b3..1920453c7 100644 --- a/jni/CMakeLists.txt +++ b/jni/CMakeLists.txt @@ -156,7 +156,8 @@ if ("${WIN32}" STREQUAL "") tests/commons_test.cpp tests/faiss_stream_support_test.cpp tests/faiss_index_service_test.cpp - ) + tests/nmslib_stream_support_test.cpp + ) target_link_libraries( jni_test diff --git a/jni/cmake/init-nmslib.cmake b/jni/cmake/init-nmslib.cmake index 64df457c1..2554b2bd7 100644 --- a/jni/cmake/init-nmslib.cmake +++ b/jni/cmake/init-nmslib.cmake @@ -12,14 +12,13 @@ if (NOT EXISTS ${NMS_REPO_DIR}) execute_process(COMMAND git submodule update --init -- external/nmslib WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) endif () - # Apply patches if(NOT DEFINED APPLY_LIB_PATCHES OR "${APPLY_LIB_PATCHES}" STREQUAL true) # Define list of patch files set(PATCH_FILE_LIST) list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0001-Initialize-maxlevel-during-add-from-enterpoint-level.patch") list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0002-Adds-ability-to-pass-ef-parameter-in-the-query-for-h.patch") - list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0003-Adding-two-apis-using-stream-to-load-save-in-Hnsw.patch") + list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0003-Added-streaming-apis-for-vector-index-loading-in-Hnsw.patch") # Get patch id of the last commit execute_process(COMMAND sh -c "git --no-pager show HEAD | git patch-id --stable" OUTPUT_VARIABLE PATCH_ID_OUTPUT_FROM_COMMIT WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/nmslib) diff --git a/jni/include/faiss_stream_support.h b/jni/include/faiss_stream_support.h index 65f1631d4..a12d66ae9 100644 --- a/jni/include/faiss_stream_support.h +++ b/jni/include/faiss_stream_support.h @@ -9,11 +9,12 @@ * GitHub history for details. */ -#ifndef OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H -#define OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H +#ifndef OPENSEARCH_KNN_JNI_FAISS_STREAM_SUPPORT_H +#define OPENSEARCH_KNN_JNI_FAISS_STREAM_SUPPORT_H #include "faiss/impl/io.h" #include "jni_util.h" +#include "native_engines_stream_support.h" #include #include @@ -23,80 +24,6 @@ namespace knn_jni { namespace stream { -/** - * This class contains Java IndexInputWithBuffer reference and calls its API to copy required bytes into a read buffer. - */ - -class NativeEngineIndexInputMediator { - public: - // Expect IndexInputWithBuffer is given as `_indexInput`. - NativeEngineIndexInputMediator(JNIUtilInterface *_jni_interface, - JNIEnv *_env, - jobject _indexInput) - : jni_interface(_jni_interface), - env(_env), - indexInput(_indexInput), - bufferArray((jbyteArray) (_jni_interface->GetObjectField(_env, - _indexInput, - getBufferFieldId(_jni_interface, _env)))), - copyBytesMethod(getCopyBytesMethod(_jni_interface, _env)) { - } - - void copyBytes(int64_t nbytes, uint8_t *destination) { - while (nbytes > 0) { - // Call `copyBytes` to read bytes as many as possible. - const auto readBytes = - jni_interface->CallIntMethodLong(env, indexInput, copyBytesMethod, nbytes); - - // === Critical Section Start === - - // Get primitive array pointer, no copy is happening in OpenJDK. - auto primitiveArray = - (jbyte *) jni_interface->GetPrimitiveArrayCritical(env, bufferArray, nullptr); - - // Copy Java bytes to C++ destination address. - std::memcpy(destination, primitiveArray, readBytes); - - // Release the acquired primitive array pointer. - // JNI_ABORT tells JVM to directly free memory without copying back to Java byte[]. - // Since we're merely copying data, we don't need to copying back. - jni_interface->ReleasePrimitiveArrayCritical(env, bufferArray, primitiveArray, JNI_ABORT); - - // === Critical Section End === - - destination += readBytes; - nbytes -= readBytes; - } // End while - } - - private: - static jclass getIndexInputWithBufferClass(JNIUtilInterface *jni_interface, JNIEnv *env) { - static jclass INDEX_INPUT_WITH_BUFFER_CLASS = - jni_interface->FindClassFromJNIEnv(env, "org/opensearch/knn/index/store/IndexInputWithBuffer"); - return INDEX_INPUT_WITH_BUFFER_CLASS; - } - - static jmethodID getCopyBytesMethod(JNIUtilInterface *jni_interface, JNIEnv *env) { - static jmethodID COPY_METHOD_ID = - jni_interface->GetMethodID(env, getIndexInputWithBufferClass(jni_interface, env), "copyBytes", "(J)I"); - return COPY_METHOD_ID; - } - - static jfieldID getBufferFieldId(JNIUtilInterface *jni_interface, JNIEnv *env) { - static jfieldID BUFFER_FIELD_ID = - jni_interface->GetFieldID(env, getIndexInputWithBufferClass(jni_interface, env), "buffer", "[B"); - return BUFFER_FIELD_ID; - } - - JNIUtilInterface *jni_interface; - JNIEnv *env; - - // `IndexInputWithBuffer` instance having `IndexInput` instance obtained from `Directory` for reading. - jobject indexInput; - jbyteArray bufferArray; - jmethodID copyBytesMethod; -}; // class NativeEngineIndexInputMediator - /** @@ -133,4 +60,4 @@ class FaissOpenSearchIOReader final : public faiss::IOReader { } } -#endif //OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H +#endif //OPENSEARCH_KNN_JNI_FAISS_STREAM_SUPPORT_H diff --git a/jni/include/jni_util.h b/jni/include/jni_util.h index 6b1b926e7..9f4daef7c 100644 --- a/jni/include/jni_util.h +++ b/jni/include/jni_util.h @@ -138,7 +138,11 @@ namespace knn_jni { virtual void ReleasePrimitiveArrayCritical(JNIEnv * env, jarray array, void *carray, jint mode) = 0; - virtual jint CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) = 0; + virtual jint CallNonvirtualIntMethodA(JNIEnv *env, jobject obj, jclass clazz, + jmethodID methodID, jvalue *args) = 0; + + virtual jlong CallNonvirtualLongMethodA(JNIEnv * env, jobject obj, jclass clazz, + jmethodID methodID, jvalue* args) = 0; // -------------------------------------------------------------------------- }; @@ -194,7 +198,8 @@ namespace knn_jni { jclass FindClassFromJNIEnv(JNIEnv * env, const char *name) final; jmethodID GetMethodID(JNIEnv * env, jclass clazz, const char *name, const char *sig) final; jfieldID GetFieldID(JNIEnv * env, jclass clazz, const char *name, const char *sig) final; - jint CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) final; + jint CallNonvirtualIntMethodA(JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, jvalue *args) final; + jlong CallNonvirtualLongMethodA(JNIEnv * env, jobject obj, jclass clazz, jmethodID methodID, jvalue* args) final; void * GetPrimitiveArrayCritical(JNIEnv * env, jarray array, jboolean *isCopy) final; void ReleasePrimitiveArrayCritical(JNIEnv * env, jarray array, void *carray, jint mode) final; diff --git a/jni/include/native_engines_stream_support.h b/jni/include/native_engines_stream_support.h new file mode 100644 index 000000000..5d4b32d3d --- /dev/null +++ b/jni/include/native_engines_stream_support.h @@ -0,0 +1,125 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +#ifndef OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H +#define OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H + +#include "jni_util.h" + +#include +#include +#include +#include + +namespace knn_jni { +namespace stream { + + + +/** + * This class contains Java IndexInputWithBuffer reference and calls its API to copy required bytes into a read buffer. + */ +class NativeEngineIndexInputMediator { + public: + // Expect IndexInputWithBuffer is given as `_indexInput`. + NativeEngineIndexInputMediator(JNIUtilInterface *_jni_interface, + JNIEnv *_env, + jobject _indexInput) + : jni_interface(_jni_interface), + env(_env), + indexInput(_indexInput), + bufferArray((jbyteArray) (_jni_interface->GetObjectField(_env, + _indexInput, + getBufferFieldId(_jni_interface, _env)))), + copyBytesMethod(getCopyBytesMethod(_jni_interface, _env)), + remainingBytesMethod(getRemainingBytesMethod(_jni_interface, _env)) { + } + + void copyBytes(int64_t nbytes, uint8_t *destination) { + auto jclazz = getIndexInputWithBufferClass(jni_interface, env); + + while (nbytes > 0) { + // Call `copyBytes` to read bytes as many as possible. + jvalue args; + args.j = nbytes; + const auto readBytes = + jni_interface->CallNonvirtualIntMethodA(env, indexInput, jclazz, copyBytesMethod, &args); + + // === Critical Section Start === + + // Get primitive array pointer, no copy is happening in OpenJDK. + auto primitiveArray = + (jbyte *) jni_interface->GetPrimitiveArrayCritical(env, bufferArray, nullptr); + + // Copy Java bytes to C++ destination address. + std::memcpy(destination, primitiveArray, readBytes); + + // Release the acquired primitive array pointer. + // JNI_ABORT tells JVM to directly free memory without copying back to Java byte[]. + // Since we're merely copying data, we don't need to copying back. + jni_interface->ReleasePrimitiveArrayCritical(env, bufferArray, primitiveArray, JNI_ABORT); + + // === Critical Section End === + + destination += readBytes; + nbytes -= readBytes; + } // End while + } + + int64_t remainingBytes() { + return jni_interface->CallNonvirtualLongMethodA(env, + indexInput, + getIndexInputWithBufferClass(jni_interface, env), + remainingBytesMethod, + nullptr); + } + + private: + static jclass getIndexInputWithBufferClass(JNIUtilInterface *jni_interface, JNIEnv *env) { + static jclass INDEX_INPUT_WITH_BUFFER_CLASS = + jni_interface->FindClassFromJNIEnv(env, "org/opensearch/knn/index/store/IndexInputWithBuffer"); + return INDEX_INPUT_WITH_BUFFER_CLASS; + } + + static jmethodID getCopyBytesMethod(JNIUtilInterface *jni_interface, JNIEnv *env) { + static jmethodID COPY_METHOD_ID = + jni_interface->GetMethodID(env, getIndexInputWithBufferClass(jni_interface, env), "copyBytes", "(J)I"); + return COPY_METHOD_ID; + } + + static jmethodID getRemainingBytesMethod(JNIUtilInterface *jni_interface, JNIEnv *env) { + static jmethodID COPY_METHOD_ID = + jni_interface->GetMethodID(env, getIndexInputWithBufferClass(jni_interface, env), "remainingBytes", "()J"); + return COPY_METHOD_ID; + } + + static jfieldID getBufferFieldId(JNIUtilInterface *jni_interface, JNIEnv *env) { + static jfieldID BUFFER_FIELD_ID = + jni_interface->GetFieldID(env, getIndexInputWithBufferClass(jni_interface, env), "buffer", "[B"); + return BUFFER_FIELD_ID; + } + + JNIUtilInterface *jni_interface; + JNIEnv *env; + + // `IndexInputWithBuffer` instance having `IndexInput` instance obtained from `Directory` for reading. + jobject indexInput; + jbyteArray bufferArray; + jmethodID copyBytesMethod; + jmethodID remainingBytesMethod; +}; // class NativeEngineIndexInputMediator + + + +} +} + +#endif //OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H diff --git a/jni/include/nmslib_stream_support.h b/jni/include/nmslib_stream_support.h new file mode 100644 index 000000000..38c06cb95 --- /dev/null +++ b/jni/include/nmslib_stream_support.h @@ -0,0 +1,51 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +#ifndef OPENSEARCH_KNN_JNI_NMSLIB_STREAM_SUPPORT_H +#define OPENSEARCH_KNN_JNI_NMSLIB_STREAM_SUPPORT_H + +#include "native_engines_stream_support.h" + +namespace knn_jni { +namespace stream { + + + +/** + * NmslibIOReader implementation delegating NativeEngineIndexInputMediator to read bytes. + */ +class NmslibOpenSearchIOReader final : public similarity::NmslibIOReader { + public: + explicit NmslibOpenSearchIOReader(NativeEngineIndexInputMediator *_mediator) + : mediator(_mediator) { + } + + void read(char *bytes, size_t len) final { + if (len > 0) { + // Mediator calls IndexInput, then copy read bytes to `ptr`. + mediator->copyBytes(len, (uint8_t *) bytes); + } + } + + size_t remainingBytes() final { + return mediator->remainingBytes(); + } + + private: + NativeEngineIndexInputMediator *mediator; +}; // class NmslibOpenSearchIOReader + + + +} +} + +#endif //OPENSEARCH_KNN_JNI_NMSLIB_STREAM_SUPPORT_H diff --git a/jni/include/nmslib_wrapper.h b/jni/include/nmslib_wrapper.h index 27a013c10..2853cd71f 100644 --- a/jni/include/nmslib_wrapper.h +++ b/jni/include/nmslib_wrapper.h @@ -33,6 +33,14 @@ namespace knn_jni { // Return a pointer to the loaded index jlong LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ, jobject parametersJ); + // Load an index via an input stream into memory. Use parametersJ to set any query time parameters + // + // Return a pointer to the loaded index + jlong LoadIndexWithStream(knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + jobject readStream, + jobject parametersJ); + // Execute a query against the index located in memory at indexPointerJ. // // Return an array of KNNQueryResults diff --git a/jni/include/org_opensearch_knn_jni_NmslibService.h b/jni/include/org_opensearch_knn_jni_NmslibService.h index a9d5238b7..8d6633aff 100644 --- a/jni/include/org_opensearch_knn_jni_NmslibService.h +++ b/jni/include/org_opensearch_knn_jni_NmslibService.h @@ -34,6 +34,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndex JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_NmslibService_loadIndex (JNIEnv *, jclass, jstring, jobject); +/* + * Class: org_opensearch_knn_jni_NmslibService + * Method: loadIndexWithStream + * Signature: (Lorg/opensearch/knn/index/store/IndexInputWithBuffer;Ljava/util/Map;)J + */ +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_NmslibService_loadIndexWithStream + (JNIEnv *, jclass, jobject, jobject); + /* * Class: org_opensearch_knn_jni_NmslibService * Method: queryIndex diff --git a/jni/patches/nmslib/0003-Added-streaming-apis-for-vector-index-loading-in-Hnsw.patch b/jni/patches/nmslib/0003-Added-streaming-apis-for-vector-index-loading-in-Hnsw.patch new file mode 100644 index 000000000..55e7a8c81 --- /dev/null +++ b/jni/patches/nmslib/0003-Added-streaming-apis-for-vector-index-loading-in-Hnsw.patch @@ -0,0 +1,221 @@ +From 2e9b7f7117842009e081dd79e8ab8b019122a3de Mon Sep 17 00:00:00 2001 +From: Dooyong Kim +Date: Fri, 11 Oct 2024 16:19:45 -0700 +Subject: [PATCH] Added streaming apis for vector index loading in Hnsw. + +Signed-off-by: Dooyong Kim +--- + similarity_search/include/method/hnsw.h | 3 + + similarity_search/include/utils.h | 12 +++ + similarity_search/src/method/hnsw.cc | 138 +++++++++++++++++++++++- + 3 files changed, 152 insertions(+), 1 deletion(-) + +diff --git a/similarity_search/include/method/hnsw.h b/similarity_search/include/method/hnsw.h +index e6dcea7..433f98f 100644 +--- a/similarity_search/include/method/hnsw.h ++++ b/similarity_search/include/method/hnsw.h +@@ -457,6 +457,8 @@ namespace similarity { + + virtual void LoadIndex(const string &location) override; + ++ void LoadIndexWithStream(similarity::NmslibIOReader& in); ++ + Hnsw(bool PrintProgress, const Space &space, const ObjectVector &data); + void CreateIndex(const AnyParams &IndexParams) override; + +@@ -500,6 +502,7 @@ namespace similarity { + + void SaveOptimizedIndex(std::ostream& output); + void LoadOptimizedIndex(std::istream& input); ++ void LoadOptimizedIndex(NmslibIOReader& input); + + void SaveRegularIndexBin(std::ostream& output); + void LoadRegularIndexBin(std::istream& input); +diff --git a/similarity_search/include/utils.h b/similarity_search/include/utils.h +index b521c26..a3931b7 100644 +--- a/similarity_search/include/utils.h ++++ b/similarity_search/include/utils.h +@@ -299,12 +299,24 @@ inline void WriteField(ostream& out, const string& fieldName, const FieldType& f + } + } + ++struct NmslibIOReader { ++ virtual ~NmslibIOReader() = default; ++ ++ virtual void read(char* bytes, size_t len) = 0; ++ ++ virtual size_t remainingBytes() = 0; ++}; + + template + void writeBinaryPOD(ostream& out, const T& podRef) { + out.write((char*)&podRef, sizeof(T)); + } + ++template ++static void readBinaryPOD(NmslibIOReader& in, T& podRef) { ++ in.read((char*)&podRef, sizeof(T)); ++} ++ + template + static void readBinaryPOD(istream& in, T& podRef) { + in.read((char*)&podRef, sizeof(T)); +diff --git a/similarity_search/src/method/hnsw.cc b/similarity_search/src/method/hnsw.cc +index 4080b3b..662f06c 100644 +--- a/similarity_search/src/method/hnsw.cc ++++ b/similarity_search/src/method/hnsw.cc +@@ -950,7 +950,6 @@ namespace similarity { + " read so far doesn't match the number of read lines: " + ConvertToString(lineNum)); + } + +- + template + void + Hnsw::LoadRegularIndexBin(std::istream& input) { +@@ -1034,6 +1033,143 @@ namespace similarity { + + } + ++ constexpr bool _isLittleEndian() { ++ return (((uint32_t) 1) & 0xFFU) == 1; ++ } ++ ++ SIZEMASS_TYPE _readIntBigEndian(uint8_t byte0, uint8_t byte1, uint8_t byte2, uint8_t byte3) noexcept { ++ return (static_cast(byte0) << 24) | ++ (static_cast(byte1) << 16) | ++ (static_cast(byte2) << 8) | ++ static_cast(byte3); ++ } ++ ++ SIZEMASS_TYPE _readIntLittleEndian(uint8_t byte0, uint8_t byte1, uint8_t byte2, uint8_t byte3) noexcept { ++ return (static_cast(byte3) << 24) | ++ (static_cast(byte2) << 16) | ++ (static_cast(byte1) << 8) | ++ static_cast(byte0); ++ } ++ ++ template ++ void Hnsw::LoadIndexWithStream(NmslibIOReader& input) { ++ LOG(LIB_INFO) << "Loading index from an input stream(NmslibIOReader)."; ++ ++ unsigned int optimIndexFlag= 0; ++ readBinaryPOD(input, optimIndexFlag); ++ ++ if (!optimIndexFlag) { ++ throw std::runtime_error("With stream, we only support optimized index type."); ++ } else { ++ LoadOptimizedIndex(input); ++ } ++ ++ LOG(LIB_INFO) << "Finished loading index"; ++ visitedlistpool = new VisitedListPool(1, totalElementsStored_); ++ } ++ ++ template ++ void Hnsw::LoadOptimizedIndex(NmslibIOReader& input) { ++ static_assert(sizeof(SIZEMASS_TYPE) == 4, "Expected sizeof(SIZEMASS_TYPE) == 4."); ++ ++ LOG(LIB_INFO) << "Loading optimized index(NmslibIOReader)."; ++ ++ readBinaryPOD(input, totalElementsStored_); ++ readBinaryPOD(input, memoryPerObject_); ++ readBinaryPOD(input, offsetLevel0_); ++ readBinaryPOD(input, offsetData_); ++ readBinaryPOD(input, maxlevel_); ++ readBinaryPOD(input, enterpointId_); ++ readBinaryPOD(input, maxM_); ++ readBinaryPOD(input, maxM0_); ++ readBinaryPOD(input, dist_func_type_); ++ readBinaryPOD(input, searchMethod_); ++ ++ LOG(LIB_INFO) << "searchMethod: " << searchMethod_; ++ ++ fstdistfunc_ = getDistFunc(dist_func_type_); ++ iscosine_ = (dist_func_type_ == kNormCosine); ++ CHECK_MSG(fstdistfunc_ != nullptr, "Unknown distance function code: " + ConvertToString(dist_func_type_)); ++ ++ LOG(LIB_INFO) << "Total: " << totalElementsStored_ << ", Memory per object: " << memoryPerObject_; ++ size_t data_plus_links0_size = memoryPerObject_ * totalElementsStored_; ++ ++ // we allocate a few extra bytes to prevent prefetch from accessing out of range memory ++ data_level0_memory_ = (char *)malloc(data_plus_links0_size + EXTRA_MEM_PAD_SIZE); ++ CHECK(data_level0_memory_); ++ input.read(data_level0_memory_, data_plus_links0_size); ++ // we allocate a few extra bytes to prevent prefetch from accessing out of range memory ++ linkLists_ = (char **)malloc( (sizeof(void *) * totalElementsStored_) + EXTRA_MEM_PAD_SIZE); ++ CHECK(linkLists_); ++ ++ data_rearranged_.resize(totalElementsStored_); ++ ++ const size_t bufferSize = 64 * 1024; // 64KB ++ std::unique_ptr buffer (new char[bufferSize]); ++ uint32_t end = 0; ++ uint32_t pos = 0; ++ constexpr bool isLittleEndian = _isLittleEndian(); ++ ++ for (size_t i = 0, remainingBytes = input.remainingBytes(); i < totalElementsStored_; i++) { ++ if ((pos + sizeof(SIZEMASS_TYPE)) >= end) { ++ // Underflow during reading an integer size field. ++ // So the idea is to move the first partial bytes (which is < 4 bytes) to the beginning section of ++ // buffer. ++ // Ex: buffer -> [..., b0, b1] where we only have two bytes and still need to read two bytes more ++ // buffer -> [b0, b1, ...] after move the first part. firstPartLen = 2. ++ const auto firstPartLen = end - pos; ++ if (firstPartLen > 0) { ++ std::memcpy(buffer.get(), buffer.get() + pos, firstPartLen); ++ } ++ // Then, bulk load bytes from input stream. Note that the first few bytes are already occupied by ++ // earlier moving logic, hence required bytes are bufferSize - firstPartLen. ++ const auto copyBytes = std::min(remainingBytes, bufferSize - firstPartLen); ++ input.read(buffer.get() + firstPartLen, copyBytes); ++ remainingBytes -= copyBytes; ++ end = copyBytes + firstPartLen; ++ pos = 0; ++ } ++ ++ // Read data size field. ++ // Since NMSLIB directly write 4 bytes integer casting to char*, bytes outline may differ among systems. ++ SIZEMASS_TYPE linkListSize = 0; ++ if (isLittleEndian) { ++ linkListSize = _readIntLittleEndian(buffer[pos], buffer[pos + 1], buffer[pos + 2], buffer[pos + 3]); ++ } else { ++ linkListSize = _readIntBigEndian(buffer[pos], buffer[pos + 1], buffer[pos + 2], buffer[pos + 3]); ++ } ++ pos += sizeof(SIZEMASS_TYPE); ++ ++ if (linkListSize == 0) { ++ linkLists_[i] = nullptr; ++ } else { ++ linkLists_[i] = (char *)malloc(linkListSize); ++ CHECK(linkLists_[i]); ++ ++ SIZEMASS_TYPE leftLinkListData = linkListSize; ++ auto dataPtr = linkLists_[i]; ++ while (leftLinkListData > 0) { ++ if (pos >= end) { ++ // Underflow during read linked list bytes. ++ const auto copyBytes = std::min(remainingBytes, bufferSize); ++ input.read(buffer.get(), copyBytes); ++ remainingBytes -= copyBytes; ++ end = copyBytes; ++ pos = 0; ++ } ++ ++ // Read linked list bytes. ++ const auto copyBytes = std::min(leftLinkListData, end - pos); ++ std::memcpy(dataPtr, buffer.get() + pos, copyBytes); ++ dataPtr += copyBytes; ++ leftLinkListData -= copyBytes; ++ pos += copyBytes; ++ } // End while ++ } // End if ++ ++ data_rearranged_[i] = new Object(data_level0_memory_ + (i)*memoryPerObject_ + offsetData_); ++ } // End for ++ } + + template + void +-- +2.39.5 (Apple Git-154) + diff --git a/jni/patches/nmslib/0003-Adding-two-apis-using-stream-to-load-save-in-Hnsw.patch b/jni/patches/nmslib/0003-Adding-two-apis-using-stream-to-load-save-in-Hnsw.patch deleted file mode 100644 index bbba329b4..000000000 --- a/jni/patches/nmslib/0003-Adding-two-apis-using-stream-to-load-save-in-Hnsw.patch +++ /dev/null @@ -1,93 +0,0 @@ -From 7e099ec111e5c9db4b243da249c73f0ecc206281 Mon Sep 17 00:00:00 2001 -From: Dooyong Kim -Date: Thu, 26 Sep 2024 15:20:53 -0700 -Subject: [PATCH] Adding two apis using stream to load/save in Hnsw. - -Signed-off-by: Dooyong Kim ---- - similarity_search/include/method/hnsw.h | 4 +++ - similarity_search/src/method/hnsw.cc | 44 +++++++++++++++++++++++++ - 2 files changed, 48 insertions(+) - -diff --git a/similarity_search/include/method/hnsw.h b/similarity_search/include/method/hnsw.h -index 57d99d0..7ff3f3d 100644 ---- a/similarity_search/include/method/hnsw.h -+++ b/similarity_search/include/method/hnsw.h -@@ -455,8 +455,12 @@ namespace similarity { - public: - virtual void SaveIndex(const string &location) override; - -+ void SaveIndexWithStream(std::ostream& output); -+ - virtual void LoadIndex(const string &location) override; - -+ void LoadIndexWithStream(std::istream& in); -+ - Hnsw(bool PrintProgress, const Space &space, const ObjectVector &data); - void CreateIndex(const AnyParams &IndexParams) override; - -diff --git a/similarity_search/src/method/hnsw.cc b/similarity_search/src/method/hnsw.cc -index 35b372c..e7a2c9e 100644 ---- a/similarity_search/src/method/hnsw.cc -+++ b/similarity_search/src/method/hnsw.cc -@@ -771,6 +771,25 @@ namespace similarity { - output.close(); - } - -+ template -+ void Hnsw::SaveIndexWithStream(std::ostream &output) { -+ output.exceptions(ios::badbit | ios::failbit); -+ -+ unsigned int optimIndexFlag = data_level0_memory_ != nullptr; -+ -+ writeBinaryPOD(output, optimIndexFlag); -+ -+ if (!optimIndexFlag) { -+#if USE_TEXT_REGULAR_INDEX -+ SaveRegularIndexText(output); -+#else -+ SaveRegularIndexBin(output); -+#endif -+ } else { -+ SaveOptimizedIndex(output); -+ } -+ } -+ - template - void - Hnsw::SaveOptimizedIndex(std::ostream& output) { -@@ -1021,6 +1040,31 @@ namespace similarity { - - } - -+ template -+ void Hnsw::LoadIndexWithStream(std::istream& input) { -+ LOG(LIB_INFO) << "Loading index from an input stream."; -+ CHECK_MSG(input, "Cannot open file for reading with an input stream"); -+ -+ input.exceptions(ios::badbit | ios::failbit); -+ -+#if USE_TEXT_REGULAR_INDEX -+ LoadRegularIndexText(input); -+#else -+ unsigned int optimIndexFlag= 0; -+ -+ readBinaryPOD(input, optimIndexFlag); -+ -+ if (!optimIndexFlag) { -+ LoadRegularIndexBin(input); -+ } else { -+ LoadOptimizedIndex(input); -+ } -+#endif -+ -+ LOG(LIB_INFO) << "Finished loading index"; -+ visitedlistpool = new VisitedListPool(1, totalElementsStored_); -+ } -+ - - template - void --- -2.39.5 (Apple Git-154) - diff --git a/jni/src/jni_util.cpp b/jni/src/jni_util.cpp index 3eaf3b0a1..8dc818c94 100644 --- a/jni/src/jni_util.cpp +++ b/jni/src/jni_util.cpp @@ -563,8 +563,14 @@ jfieldID knn_jni::JNIUtil::GetFieldID(JNIEnv * env, jclass clazz, const char *na return env->GetFieldID(clazz, name, sig); } -jint knn_jni::JNIUtil::CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) { - return env->CallIntMethod(obj, methodID, longArg); +jint knn_jni::JNIUtil::CallNonvirtualIntMethodA(JNIEnv * env, jobject obj, jclass clazz, + jmethodID methodID, jvalue* args) { + return env->CallNonvirtualIntMethodA(obj, clazz, methodID, args); +} + +jlong knn_jni::JNIUtil::CallNonvirtualLongMethodA(JNIEnv * env, jobject obj, jclass clazz, + jmethodID methodID, jvalue* args) { + return env->CallNonvirtualLongMethodA(obj, clazz, methodID, args); } void * knn_jni::JNIUtil::GetPrimitiveArrayCritical(JNIEnv * env, jarray array, jboolean *isCopy) { diff --git a/jni/src/nmslib_wrapper.cpp b/jni/src/nmslib_wrapper.cpp index 21b34eb83..536558caa 100644 --- a/jni/src/nmslib_wrapper.cpp +++ b/jni/src/nmslib_wrapper.cpp @@ -11,6 +11,7 @@ #include "jni_util.h" #include "nmslib_wrapper.h" +#include "nmslib_stream_support.h" #include "commons.h" @@ -25,303 +26,365 @@ #include #include +#include #include "hnswquery.h" +#include "method/hnsw.h" - -std::string TranslateSpaceType(const std::string& spaceType); +std::string TranslateSpaceType(const std::string &spaceType); // We do not use label functionality of nmslib so we pass default label. Setting as a const allows us to avoid a few // allocations const similarity::LabelType DEFAULT_LABEL = -1; -void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, +void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jobject parametersJ) { - if (idsJ == nullptr) { - throw std::runtime_error("IDs cannot be null"); - } + if (idsJ == nullptr) { + throw std::runtime_error("IDs cannot be null"); + } - if (vectorsAddressJ <= 0) { - throw std::runtime_error("VectorsAddress cannot be less than 0"); - } + if (vectorsAddressJ <= 0) { + throw std::runtime_error("VectorsAddress cannot be less than 0"); + } - if(dimJ <= 0) { - throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); - } + if (dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); + } - if (indexPathJ == nullptr) { - throw std::runtime_error("Index path cannot be null"); - } + if (indexPathJ == nullptr) { + throw std::runtime_error("Index path cannot be null"); + } - if (parametersJ == nullptr) { - throw std::runtime_error("Parameters cannot be null"); - } - - // Handle parameters - auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); - std::vector indexParameters; + if (parametersJ == nullptr) { + throw std::runtime_error("Parameters cannot be null"); + } - // Algorithm parameters will be in a sub map - if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { - jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS]; - auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ); + // Handle parameters + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + std::vector indexParameters; - if(subParametersCpp.find(knn_jni::EF_CONSTRUCTION) != subParametersCpp.end()) { - auto efConstruction = jniUtil->ConvertJavaObjectToCppInteger(env, subParametersCpp[knn_jni::EF_CONSTRUCTION]); - indexParameters.push_back(knn_jni::EF_CONSTRUCTION_NMSLIB + "=" + std::to_string(efConstruction)); - } + // Algorithm parameters will be in a sub map + if (parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { + jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS]; + auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ); - if(subParametersCpp.find(knn_jni::M) != subParametersCpp.end()) { - auto m = jniUtil->ConvertJavaObjectToCppInteger(env, subParametersCpp[knn_jni::M]); - indexParameters.push_back(knn_jni::M_NMSLIB + "=" + std::to_string(m)); - } - - jniUtil->DeleteLocalRef(env, subParametersJ); + if (subParametersCpp.find(knn_jni::EF_CONSTRUCTION) != subParametersCpp.end()) { + auto efConstruction = jniUtil->ConvertJavaObjectToCppInteger(env, subParametersCpp[knn_jni::EF_CONSTRUCTION]); + indexParameters.push_back(knn_jni::EF_CONSTRUCTION_NMSLIB + "=" + std::to_string(efConstruction)); } - if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { - auto indexThreadQty = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); - indexParameters.push_back(knn_jni::INDEX_THREAD_QUANTITY + "=" + std::to_string(indexThreadQty)); + if (subParametersCpp.find(knn_jni::M) != subParametersCpp.end()) { + auto m = jniUtil->ConvertJavaObjectToCppInteger(env, subParametersCpp[knn_jni::M]); + indexParameters.push_back(knn_jni::M_NMSLIB + "=" + std::to_string(m)); } - jniUtil->DeleteLocalRef(env, parametersJ); - - // Get the path to save the index - std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); - - // Get space type for this index - jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); - std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); - spaceTypeCpp = TranslateSpaceType(spaceTypeCpp); - - std::unique_ptr> space; - space.reset(similarity::SpaceFactoryRegistry::Instance().CreateSpace(spaceTypeCpp,similarity::AnyParams())); - - // Get number of ids and vectors and dimension - auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); - int dim = (int)dimJ; - // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value - int numVectors = (int) ( inputVectors->size() / (uint64_t) dim); - if(numVectors == 0) { - throw std::runtime_error("Number of vectors cannot be 0"); + jniUtil->DeleteLocalRef(env, subParametersJ); + } + + if (parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto indexThreadQty = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + indexParameters.push_back(knn_jni::INDEX_THREAD_QUANTITY + "=" + std::to_string(indexThreadQty)); + } + + jniUtil->DeleteLocalRef(env, parametersJ); + + // Get the path to save the index + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + + // Get space type for this index + jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); + std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); + spaceTypeCpp = TranslateSpaceType(spaceTypeCpp); + + std::unique_ptr> space; + space.reset(similarity::SpaceFactoryRegistry::Instance().CreateSpace(spaceTypeCpp, similarity::AnyParams())); + + // Get number of ids and vectors and dimension + auto *inputVectors = reinterpret_cast *>(vectorsAddressJ); + int dim = (int) dimJ; + // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value + int numVectors = (int) (inputVectors->size() / (uint64_t) dim); + if (numVectors == 0) { + throw std::runtime_error("Number of vectors cannot be 0"); + } + + int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); + if (numIds != numVectors) { + throw std::runtime_error("Number of IDs does not match number of vectors"); + } + + // Read dataset + similarity::ObjectVector dataset; + dataset.reserve(numVectors); + int *idsCpp; + try { + // Read in data set + idsCpp = jniUtil->GetIntArrayElements(env, idsJ, nullptr); + size_t vectorSizeInBytes = dim * sizeof(float); + // vectorPointer needs to be unsigned long long, this will ensure that out of range doesn't happen for this pointer + // when the values of numVectors * dim becomes very large. + // Example: for 10M vectors of 1536 dim vectorPointer max value will be ~15.3B which is already > range of ints. + // keeping it unsigned long long we will never go above the range. + unsigned long long vectorPointer = 0; + + // Allocate a large buffer that will contain all the vectors. Allocating the objects in one large buffer as + // opposed to individually will prevent heap fragmentation. We have observed that allocating individual + // objects causes RSS to rise throughout the lifetime of a process + // (see https://github.com/opensearch-project/k-NN/issues/772 and + // https://github.com/opensearch-project/k-NN/issues/72). This is because, in typical systems, small + // allocations will reside on some kind of heap managed by an allocator. Once freed, the allocator does not + // always return the memory to the OS. If the heap gets fragmented, this will cause the allocator + // to ask for more memory, causing RSS to grow. On large allocations (> 128 kb), most allocators will + // internally use mmap. Once freed, unmap will be called, which will immediately return memory to the OS + // which in turn prevents RSS from growing out of control. Wrap with a smart pointer so that buffer will be + // freed once variable goes out of scope. For reference, the code that specifies the layout of the buffer can be + // found: https://github.com/nmslib/nmslib/blob/v2.1.1/similarity_search/include/object.h#L61-L75 + std::unique_ptr objectBuffer + (new char[(similarity::ID_SIZE + similarity::LABEL_SIZE + similarity::DATALENGTH_SIZE + vectorSizeInBytes) + * numVectors]); + char *ptr = objectBuffer.get(); + for (int i = 0; i < numVectors; i++) { + dataset.push_back(new similarity::Object(ptr)); + + memcpy(ptr, &idsCpp[i], similarity::ID_SIZE); + ptr += similarity::ID_SIZE; + memcpy(ptr, &DEFAULT_LABEL, similarity::LABEL_SIZE); + ptr += similarity::LABEL_SIZE; + memcpy(ptr, &vectorSizeInBytes, similarity::DATALENGTH_SIZE); + ptr += similarity::DATALENGTH_SIZE; + + memcpy(ptr, &(inputVectors->at(vectorPointer)), vectorSizeInBytes); + ptr += vectorSizeInBytes; + vectorPointer += dim; } - - int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); - if (numIds != numVectors) { - throw std::runtime_error("Number of IDs does not match number of vectors"); + jniUtil->ReleaseIntArrayElements(env, idsJ, idsCpp, JNI_ABORT); + + // Releasing the vectorsAddressJ memory as that is not required once we have created the index. + // This is not the ideal approach, please refer this gh issue for long term solution: + // https://github.com/opensearch-project/k-NN/issues/1600 + //commons::freeVectorData(vectorsAddressJ); + delete inputVectors; + + std::unique_ptr> index; + index.reset(similarity::MethodFactoryRegistry::Instance().CreateMethod(false, + "hnsw", + spaceTypeCpp, + *(space), + dataset)); + index->CreateIndex(similarity::AnyParams(indexParameters)); + index->SaveIndex(indexPathCpp); + + for (auto &it : dataset) { + delete it; } - - // Read dataset - similarity::ObjectVector dataset; - dataset.reserve(numVectors); - int* idsCpp; - try { - // Read in data set - idsCpp = jniUtil->GetIntArrayElements(env, idsJ, nullptr); - size_t vectorSizeInBytes = dim*sizeof(float); - // vectorPointer needs to be unsigned long long, this will ensure that out of range doesn't happen for this pointer - // when the values of numVectors * dim becomes very large. - // Example: for 10M vectors of 1536 dim vectorPointer max value will be ~15.3B which is already > range of ints. - // keeping it unsigned long long we will never go above the range. - unsigned long long vectorPointer = 0; - - // Allocate a large buffer that will contain all the vectors. Allocating the objects in one large buffer as - // opposed to individually will prevent heap fragmentation. We have observed that allocating individual - // objects causes RSS to rise throughout the lifetime of a process - // (see https://github.com/opensearch-project/k-NN/issues/772 and - // https://github.com/opensearch-project/k-NN/issues/72). This is because, in typical systems, small - // allocations will reside on some kind of heap managed by an allocator. Once freed, the allocator does not - // always return the memory to the OS. If the heap gets fragmented, this will cause the allocator - // to ask for more memory, causing RSS to grow. On large allocations (> 128 kb), most allocators will - // internally use mmap. Once freed, unmap will be called, which will immediately return memory to the OS - // which in turn prevents RSS from growing out of control. Wrap with a smart pointer so that buffer will be - // freed once variable goes out of scope. For reference, the code that specifies the layout of the buffer can be - // found: https://github.com/nmslib/nmslib/blob/v2.1.1/similarity_search/include/object.h#L61-L75 - std::unique_ptr objectBuffer(new char[(similarity::ID_SIZE + similarity::LABEL_SIZE + similarity::DATALENGTH_SIZE + vectorSizeInBytes) * numVectors]); - char* ptr = objectBuffer.get(); - for (int i = 0; i < numVectors; i++) { - dataset.push_back(new similarity::Object(ptr)); - - memcpy(ptr, &idsCpp[i], similarity::ID_SIZE); - ptr += similarity::ID_SIZE; - memcpy(ptr, &DEFAULT_LABEL, similarity::LABEL_SIZE); - ptr += similarity::LABEL_SIZE; - memcpy(ptr, &vectorSizeInBytes, similarity::DATALENGTH_SIZE); - ptr += similarity::DATALENGTH_SIZE; - - memcpy(ptr, &(inputVectors->at(vectorPointer)), vectorSizeInBytes); - ptr += vectorSizeInBytes; - vectorPointer += dim; - } - jniUtil->ReleaseIntArrayElements(env, idsJ, idsCpp, JNI_ABORT); - - // Releasing the vectorsAddressJ memory as that is not required once we have created the index. - // This is not the ideal approach, please refer this gh issue for long term solution: - // https://github.com/opensearch-project/k-NN/issues/1600 - //commons::freeVectorData(vectorsAddressJ); - delete inputVectors; - - std::unique_ptr> index; - index.reset(similarity::MethodFactoryRegistry::Instance().CreateMethod(false, "hnsw", spaceTypeCpp, *(space), dataset)); - index->CreateIndex(similarity::AnyParams(indexParameters)); - index->SaveIndex(indexPathCpp); - - for (auto & it : dataset) { - delete it; - } - } catch (...) { - for (auto & it : dataset) { - delete it; - } - - jniUtil->ReleaseIntArrayElements(env, idsJ, idsCpp, JNI_ABORT); - throw; + } catch (...) { + for (auto &it : dataset) { + delete it; } + + jniUtil->ReleaseIntArrayElements(env, idsJ, idsCpp, JNI_ABORT); + throw; + } } -jlong knn_jni::nmslib_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ, +jlong knn_jni::nmslib_wrapper::LoadIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jstring indexPathJ, jobject parametersJ) { - if (indexPathJ == nullptr) { - throw std::runtime_error("Index path cannot be null"); - } + if (indexPathJ == nullptr) { + throw std::runtime_error("Index path cannot be null"); + } - if (parametersJ == nullptr) { - throw std::runtime_error("Parameters cannot be null"); - } + if (parametersJ == nullptr) { + throw std::runtime_error("Parameters cannot be null"); + } - std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); - auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); - // Get space type for this index - jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); - std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); - spaceTypeCpp = TranslateSpaceType(spaceTypeCpp); + // Get space type for this index + jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); + std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); + spaceTypeCpp = TranslateSpaceType(spaceTypeCpp); - // Parse query params - std::vector queryParams; + // Parse query params + std::vector queryParams; - if(parametersCpp.find("efSearch") != parametersCpp.end()) { - auto efSearch = std::to_string(jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp["efSearch"])); - queryParams.push_back("efSearch=" + efSearch); - } + auto it = parametersCpp.find("efSearch"); + if (it != parametersCpp.end()) { + auto efSearch = std::to_string(jniUtil->ConvertJavaObjectToCppInteger(env, it->second)); + queryParams.push_back("efSearch=" + efSearch); + } - // Load index - knn_jni::nmslib_wrapper::IndexWrapper * indexWrapper; - try { - indexWrapper = new knn_jni::nmslib_wrapper::IndexWrapper(spaceTypeCpp); - indexWrapper->index->LoadIndex(indexPathCpp); - indexWrapper->index->SetQueryTimeParams(similarity::AnyParams(queryParams)); - } catch (...) { - delete indexWrapper; - throw; + // Load index + knn_jni::nmslib_wrapper::IndexWrapper *indexWrapper = nullptr; + try { + indexWrapper = new knn_jni::nmslib_wrapper::IndexWrapper(spaceTypeCpp); + indexWrapper->index->LoadIndex(indexPathCpp); + indexWrapper->index->SetQueryTimeParams(similarity::AnyParams(queryParams)); + } catch (...) { + delete indexWrapper; + throw; + } + + return (jlong) indexWrapper; +} + +jlong knn_jni::nmslib_wrapper::LoadIndexWithStream(knn_jni::JNIUtilInterface *jniUtil, + JNIEnv *env, + jobject readStream, + jobject parametersJ) { + if (readStream == nullptr) { + throw std::runtime_error("Read stream cannot be null"); + } + + if (parametersJ == nullptr) { + throw std::runtime_error("Parameters cannot be null"); + } + + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + + // Get space type for this index + jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); + std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); + spaceTypeCpp = TranslateSpaceType(spaceTypeCpp); + + // Parse query params + std::vector queryParams; + + auto it = parametersCpp.find("efSearch"); + if (it != parametersCpp.end()) { + auto efSearch = std::to_string(jniUtil->ConvertJavaObjectToCppInteger(env, it->second)); + queryParams.push_back("efSearch=" + efSearch); + } + + // Create a mediator locally. + // Note that `indexInput` is `IndexInputWithBuffer` type. + knn_jni::stream::NativeEngineIndexInputMediator mediator{jniUtil, env, readStream}; + + knn_jni::stream::NmslibOpenSearchIOReader ioReader {&mediator}; + + // Load index + knn_jni::nmslib_wrapper::IndexWrapper *indexWrapper = nullptr; + try { + indexWrapper = new knn_jni::nmslib_wrapper::IndexWrapper(spaceTypeCpp); + indexWrapper->index->SetQueryTimeParams(similarity::AnyParams(queryParams)); + + if (auto hnswFloatIndex = dynamic_cast *>(indexWrapper->index.get())) { + hnswFloatIndex->LoadIndexWithStream(ioReader); + } else { + throw std::runtime_error("We only support similarity::Hnsw in NMSLIB."); } + } catch (...) { + delete indexWrapper; + throw; + } - return (jlong) indexWrapper; + return (jlong) indexWrapper; } -jobjectArray knn_jni::nmslib_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, +jobjectArray knn_jni::nmslib_wrapper::QueryIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ) { - if (queryVectorJ == nullptr) { - throw std::runtime_error("Query Vector cannot be null"); - } - - if (indexPointerJ == 0) { - throw std::runtime_error("Invalid pointer to index"); - } + if (queryVectorJ == nullptr) { + throw std::runtime_error("Query Vector cannot be null"); + } - auto *indexWrapper = reinterpret_cast(indexPointerJ); + if (indexPointerJ == 0) { + throw std::runtime_error("Invalid pointer to index"); + } - int dim = jniUtil->GetJavaFloatArrayLength(env, queryVectorJ); + auto *indexWrapper = reinterpret_cast(indexPointerJ); - float* rawQueryvector = jniUtil->GetFloatArrayElements(env, queryVectorJ, nullptr); // Have to call release on this + int dim = jniUtil->GetJavaFloatArrayLength(env, queryVectorJ); - std::unique_ptr queryObject; - try { - queryObject.reset(new similarity::Object(-1, -1, dim*sizeof(float), rawQueryvector)); - } catch (...) { - jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); - throw; - } + float *rawQueryvector = jniUtil->GetFloatArrayElements(env, queryVectorJ, nullptr); // Have to call release on this + std::unique_ptr queryObject; + try { + queryObject.reset(new similarity::Object(-1, -1, dim * sizeof(float), rawQueryvector)); + } catch (...) { jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); - std::unordered_map methodParams; - if (methodParamsJ != nullptr) { - methodParams = jniUtil->ConvertJavaMapToCppMap(env, methodParamsJ); + throw; + } + + jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); + std::unordered_map methodParams; + if (methodParamsJ != nullptr) { + methodParams = jniUtil->ConvertJavaMapToCppMap(env, methodParamsJ); + } + + int queryEfSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, -1); + similarity::KNNQuery + *query; // TODO: Replace with smart pointers https://github.com/opensearch-project/k-NN/issues/1785 + std::unique_ptr> neighbors; + try { + if (queryEfSearch == -1) { + query = new similarity::KNNQuery(*(indexWrapper->space), queryObject.get(), kJ); + } else { + query = new similarity::HNSWQuery(*(indexWrapper->space), queryObject.get(), kJ, queryEfSearch); } - int queryEfSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, -1); - similarity::KNNQuery* query; // TODO: Replace with smart pointers https://github.com/opensearch-project/k-NN/issues/1785 - std::unique_ptr> neighbors; - try { - if (queryEfSearch == -1) { - query = new similarity::KNNQuery(*(indexWrapper->space), queryObject.get(), kJ); - } else { - query = new similarity::HNSWQuery(*(indexWrapper->space), queryObject.get(), kJ, queryEfSearch); - } - - indexWrapper->index->Search(query); - neighbors.reset(query->Result()->Clone()); - } catch (...) { - if (query != nullptr) { - delete query; - } - throw; + indexWrapper->index->Search(query); + neighbors.reset(query->Result()->Clone()); + } catch (...) { + if (query != nullptr) { + delete query; } - delete query; - - int resultSize = neighbors->Size(); - jclass resultClass = jniUtil->FindClass(env,"org/opensearch/knn/index/query/KNNQueryResult"); - jmethodID allArgs = jniUtil->FindMethod(env, "org/opensearch/knn/index/query/KNNQueryResult", ""); - - jobjectArray results = jniUtil->NewObjectArray(env, resultSize, resultClass, nullptr); - - jobject result; - float distance; - long id; - for(int i = 0; i < resultSize; ++i) { - distance = neighbors->TopDistance(); - id = neighbors->Pop()->id(); - result = jniUtil->NewObject(env, resultClass, allArgs, id, distance); - jniUtil->SetObjectArrayElement(env, results, i, result); - } - - return results; + throw; + } + delete query; + + int resultSize = neighbors->Size(); + jclass resultClass = jniUtil->FindClass(env, "org/opensearch/knn/index/query/KNNQueryResult"); + jmethodID allArgs = jniUtil->FindMethod(env, "org/opensearch/knn/index/query/KNNQueryResult", ""); + + jobjectArray results = jniUtil->NewObjectArray(env, resultSize, resultClass, nullptr); + + jobject result; + float distance; + long id; + for (int i = 0; i < resultSize; ++i) { + distance = neighbors->TopDistance(); + id = neighbors->Pop()->id(); + result = jniUtil->NewObject(env, resultClass, allArgs, id, distance); + jniUtil->SetObjectArrayElement(env, results, i, result); + } + + return results; } void knn_jni::nmslib_wrapper::Free(jlong indexPointerJ) { - auto *indexWrapper = reinterpret_cast(indexPointerJ); - delete indexWrapper; + auto *indexWrapper = reinterpret_cast(indexPointerJ); + delete indexWrapper; } void knn_jni::nmslib_wrapper::InitLibrary() { - similarity::initLibrary(); + similarity::initLibrary(); } -std::string TranslateSpaceType(const std::string& spaceType) { - if (spaceType == knn_jni::L2) { - return spaceType; - } +std::string TranslateSpaceType(const std::string &spaceType) { + if (spaceType == knn_jni::L2) { + return spaceType; + } - if (spaceType == knn_jni::L1) { - return spaceType; - } + if (spaceType == knn_jni::L1) { + return spaceType; + } - if (spaceType == knn_jni::LINF) { - return spaceType; - } + if (spaceType == knn_jni::LINF) { + return spaceType; + } - if (spaceType == knn_jni::COSINESIMIL) { - return spaceType; - } + if (spaceType == knn_jni::COSINESIMIL) { + return spaceType; + } - if (spaceType == knn_jni::INNER_PRODUCT) { - return knn_jni::NEG_DOT_PRODUCT; - } + if (spaceType == knn_jni::INNER_PRODUCT) { + return knn_jni::NEG_DOT_PRODUCT; + } - throw std::runtime_error("Invalid spaceType"); + throw std::runtime_error("Invalid spaceType"); } diff --git a/jni/src/org_opensearch_knn_jni_NmslibService.cpp b/jni/src/org_opensearch_knn_jni_NmslibService.cpp index e265827cd..8e4df2e9c 100644 --- a/jni/src/org_opensearch_knn_jni_NmslibService.cpp +++ b/jni/src/org_opensearch_knn_jni_NmslibService.cpp @@ -12,7 +12,6 @@ #include "org_opensearch_knn_jni_NmslibService.h" #include -#include #include "jni_util.h" #include "nmslib_wrapper.h" @@ -20,71 +19,85 @@ static knn_jni::JNIUtil jniUtil; static const jint KNN_NMSLIB_JNI_VERSION = JNI_VERSION_1_1; -jint JNI_OnLoad(JavaVM* vm, void* reserved) { - JNIEnv* env; - if (vm->GetEnv((void**)&env, KNN_NMSLIB_JNI_VERSION) != JNI_OK) { - return JNI_ERR; - } +jint JNI_OnLoad(JavaVM *vm, void *reserved) { + JNIEnv *env; + if (vm->GetEnv((void **) &env, KNN_NMSLIB_JNI_VERSION) != JNI_OK) { + return JNI_ERR; + } - jniUtil.Initialize(env); + jniUtil.Initialize(env); - return KNN_NMSLIB_JNI_VERSION; + return KNN_NMSLIB_JNI_VERSION; } void JNI_OnUnload(JavaVM *vm, void *reserved) { - JNIEnv* env; - vm->GetEnv((void**)&env, KNN_NMSLIB_JNI_VERSION); - jniUtil.Uninitialize(env); + JNIEnv *env; + vm->GetEnv((void **) &env, KNN_NMSLIB_JNI_VERSION); + jniUtil.Uninitialize(env); } -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndex(JNIEnv * env, jclass cls, jintArray idsJ, - jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, - jobject parametersJ) -{ - try { - knn_jni::nmslib_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, parametersJ); - } catch (...) { - jniUtil.CatchCppExceptionAndThrowJava(env); - } +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndex(JNIEnv *env, + jclass cls, + jintArray idsJ, + jlong vectorsAddressJ, + jint dimJ, + jstring indexPathJ, + jobject parametersJ) { + try { + knn_jni::nmslib_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, parametersJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } } -JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_NmslibService_loadIndex(JNIEnv * env, jclass cls, - jstring indexPathJ, jobject parametersJ) -{ - try { - return knn_jni::nmslib_wrapper::LoadIndex(&jniUtil, env, indexPathJ, parametersJ); - } catch (...) { - jniUtil.CatchCppExceptionAndThrowJava(env); - } - return NULL; +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_NmslibService_loadIndex(JNIEnv *env, jclass cls, + jstring indexPathJ, jobject parametersJ) { + try { + return knn_jni::nmslib_wrapper::LoadIndex(&jniUtil, env, indexPathJ, parametersJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return NULL; } -JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_NmslibService_queryIndex(JNIEnv * env, jclass cls, +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_NmslibService_loadIndexWithStream(JNIEnv *env, + jclass cls, + jobject readStream, + jobject parametersJ) { + try { + return knn_jni::nmslib_wrapper::LoadIndexWithStream(&jniUtil, env, readStream, parametersJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return NULL; +} + +JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_NmslibService_queryIndex(JNIEnv *env, + jclass cls, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ) -{ - try { - return knn_jni::nmslib_wrapper::QueryIndex(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, methodParamsJ); - } catch (...) { - jniUtil.CatchCppExceptionAndThrowJava(env); - } - return nullptr; + jfloatArray queryVectorJ, + jint kJ, + jobject methodParamsJ) { + try { + return knn_jni::nmslib_wrapper::QueryIndex(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, methodParamsJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return nullptr; } -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_free(JNIEnv * env, jclass cls, jlong indexPointerJ) -{ - try { - return knn_jni::nmslib_wrapper::Free(indexPointerJ); - } catch (...) { - jniUtil.CatchCppExceptionAndThrowJava(env); - } +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_free(JNIEnv *env, jclass cls, jlong indexPointerJ) { + try { + return knn_jni::nmslib_wrapper::Free(indexPointerJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } } -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_initLibrary(JNIEnv * env, jclass cls) -{ - try { - knn_jni::nmslib_wrapper::InitLibrary(); - } catch (...) { - jniUtil.CatchCppExceptionAndThrowJava(env); - } +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_initLibrary(JNIEnv *env, jclass cls) { + try { + knn_jni::nmslib_wrapper::InitLibrary(); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } } diff --git a/jni/tests/faiss_stream_support_test.cpp b/jni/tests/faiss_stream_support_test.cpp index 4045985bb..94a9b3991 100644 --- a/jni/tests/faiss_stream_support_test.cpp +++ b/jni/tests/faiss_stream_support_test.cpp @@ -8,83 +8,49 @@ // GitHub history for details. #include "faiss_stream_support.h" -#include +#include "native_stream_support_util.h" #include "test_util.h" + +#include #include -#include #include #include +using ::testing::_; using ::testing::Return; using knn_jni::stream::FaissOpenSearchIOReader; using knn_jni::stream::NativeEngineIndexInputMediator; using test_util::MockJNIUtil; - -// Mocking IndexInputWithBuffer. -struct JavaIndexInputMock { - JavaIndexInputMock(std::string _readTargetBytes, int32_t _bufSize) - : readTargetBytes(std::move(_readTargetBytes)), - nextReadIdx(), - buffer(_bufSize) { - } - - // This method is simulating `copyBytes` in IndexInputWithBuffer. - int32_t simulateCopyReads(int64_t readBytes) { - readBytes = std::min(readBytes, (int64_t) buffer.size()); - readBytes = std::min(readBytes, (int64_t) (readTargetBytes.size() - nextReadIdx)); - std::memcpy(buffer.data(), readTargetBytes.data() + nextReadIdx, readBytes); - nextReadIdx += readBytes; - return (int32_t) readBytes; - } - - static std::string makeRandomBytes(int32_t bytesSize) { - // Define the list of possible characters - static const string CHARACTERS - = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuv" - "wxyz0123456789"; - - // Create a random number generator - std::random_device rd; - std::mt19937 generator(rd()); - - // Create a distribution to uniformly select from all characters - std::uniform_int_distribution<> distribution( - 0, CHARACTERS.size() - 1); - - // Pre-allocate the string with the desired length - std::string randomString(bytesSize, '\0'); - - // Use generate_n with a back_inserter iterator - std::generate_n(randomString.begin(), bytesSize, [&]() { - return CHARACTERS[distribution(generator)]; - }); - - return randomString; - } - - std::string readTargetBytes; - int64_t nextReadIdx; - std::vector buffer; -}; // struct JavaIndexInputMock +using test_util::JavaIndexInputMock; +using ::testing::NiceMock; +using ::testing::Return; void setUpMockJNIUtil(JavaIndexInputMock &javaIndexInputMock, MockJNIUtil &mockJni) { // Set up mocking values + mocking behavior in a method. - ON_CALL(mockJni, FindClassFromJNIEnv).WillByDefault(Return((jclass) 1)); - ON_CALL(mockJni, GetMethodID).WillByDefault(Return((jmethodID) 1)); - ON_CALL(mockJni, GetFieldID).WillByDefault(Return((jfieldID) 1)); - ON_CALL(mockJni, GetObjectField).WillByDefault(Return((jobject) 1)); - ON_CALL(mockJni, CallIntMethodLong).WillByDefault([&javaIndexInputMock](JNIEnv *env, - jobject obj, - jmethodID methodID, - int64_t longArg) { - return javaIndexInputMock.simulateCopyReads(longArg); - }); - ON_CALL(mockJni, GetPrimitiveArrayCritical).WillByDefault([&javaIndexInputMock](JNIEnv *env, - jarray array, - jboolean *isCopy) { - return (jbyte *) javaIndexInputMock.buffer.data(); - }); - ON_CALL(mockJni, ReleasePrimitiveArrayCritical).WillByDefault(Return()); + EXPECT_CALL(mockJni, CallNonvirtualIntMethodA(_, _, _, _, _)) + .WillRepeatedly([&javaIndexInputMock](JNIEnv *env, + jobject obj, + jclass clazz, + jmethodID methodID, + jvalue* args) { + return javaIndexInputMock.simulateCopyReads(args[0].j); + }); + EXPECT_CALL(mockJni, CallNonvirtualLongMethodA(_, _, _, _, _)) + .WillRepeatedly([&javaIndexInputMock](JNIEnv *env, + jobject obj, + jclass clazz, + jmethodID methodID, + jvalue* args) { + return javaIndexInputMock.remainingBytes(); + }); + EXPECT_CALL(mockJni, GetPrimitiveArrayCritical(_, _, _)) + .WillRepeatedly([&javaIndexInputMock](JNIEnv *env, + jarray array, + jboolean *isCopy) { + return (jbyte *) javaIndexInputMock.buffer.data(); + }); + EXPECT_CALL(mockJni, ReleasePrimitiveArrayCritical(_, _, _, _)) + .WillRepeatedly(Return()); } TEST(FaissStreamSupportTest, NativeEngineIndexInputMediatorCopyWhenEmpty) { @@ -110,7 +76,7 @@ TEST(FaissStreamSupportTest, NativeEngineIndexInputMediatorCopyWhenEmpty) { TEST(FaissStreamSupportTest, FaissOpenSearchIOReaderCopy) { for (auto contentSize : std::vector{0, 2222, 7777, 1024, 77, 1}) { // Set up mockings - MockJNIUtil mockJni; + NiceMock mockJni; JavaIndexInputMock javaIndexInputMock{ JavaIndexInputMock::makeRandomBytes(contentSize), 1024}; setUpMockJNIUtil(javaIndexInputMock, mockJni); diff --git a/jni/tests/native_stream_support_util.h b/jni/tests/native_stream_support_util.h new file mode 100644 index 000000000..e33f3beb4 --- /dev/null +++ b/jni/tests/native_stream_support_util.h @@ -0,0 +1,102 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +#ifndef KNNPLUGIN_JNI_TESTS_NATIVE_STREAM_SUPPORT_UTIL_H_ +#define KNNPLUGIN_JNI_TESTS_NATIVE_STREAM_SUPPORT_UTIL_H_ + +#include "test_util.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace test_util { + + +// Mocking IndexInputWithBuffer. +struct JavaIndexInputMock { + JavaIndexInputMock(std::string _readTargetBytes, int32_t _bufSize) + : readTargetBytes(std::move(_readTargetBytes)), + nextReadIdx(), + buffer(_bufSize) { + } + + // This method is simulating `copyBytes` in IndexInputWithBuffer. + int32_t simulateCopyReads(int64_t readBytes) { + readBytes = std::min(readBytes, (int64_t) buffer.size()); + readBytes = std::min(readBytes, (int64_t) (readTargetBytes.size() - nextReadIdx)); + std::memcpy(buffer.data(), readTargetBytes.data() + nextReadIdx, readBytes); + nextReadIdx += readBytes; + return (int32_t) readBytes; + } + + int64_t remainingBytes() { + return readTargetBytes.size() - nextReadIdx; + } + + static std::string makeRandomBytes(int32_t bytesSize) { + // Define the list of possible characters + static const string CHARACTERS + = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuv" + "wxyz0123456789"; + + // Create a random number generator + std::random_device rd; + std::mt19937 generator(rd()); + + // Create a distribution to uniformly select from all characters + std::uniform_int_distribution<> distribution( + 0, CHARACTERS.size() - 1); + + // Pre-allocate the string with the desired length + std::string randomString(bytesSize, '\0'); + + // Use generate_n with a back_inserter iterator + std::generate_n(randomString.begin(), bytesSize, [&]() { + return CHARACTERS[distribution(generator)]; + }); + + return randomString; + } + + std::string readTargetBytes; + int64_t nextReadIdx; + std::vector buffer; +}; // struct JavaIndexInputMock + + + +struct JavaFileIndexInputMock { + JavaFileIndexInputMock(std::ifstream &_file_input, int32_t _buf_size) + : file_input(_file_input), + buffer(_buf_size) { + } + + int64_t remainingBytes() { + std::streampos currentPos = file_input.tellg(); + file_input.seekg(0, std::ios::end); + std::streamsize fileSize = file_input.tellg(); + file_input.seekg(currentPos); + return fileSize - currentPos; + } + + int32_t copyBytes(int64_t read_size) { + const auto copy_size = std::min((int64_t) buffer.size(), read_size); + file_input.read(buffer.data(), copy_size); + return (int32_t) copy_size; + } + + std::ifstream &file_input; + std::vector buffer; +}; // class JavaFileIndexInputMock + + +} // namespace test_util + +#endif //KNNPLUGIN_JNI_TESTS_NATIVE_STREAM_SUPPORT_UTIL_H_ diff --git a/jni/tests/nmslib_stream_support_test.cpp b/jni/tests/nmslib_stream_support_test.cpp new file mode 100644 index 000000000..e0e7a2d08 --- /dev/null +++ b/jni/tests/nmslib_stream_support_test.cpp @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// The OpenSearch Contributors require contributions made to +// this file be licensed under the Apache-2.0 license or a +// compatible open source license. +// +// Modifications Copyright OpenSearch Contributors. See +// GitHub history for details. + +#include "nmslib_wrapper.h" + +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "jni_util.h" +#include "test_util.h" +#include "native_stream_support_util.h" + +using ::testing::_; +using ::testing::NiceMock; +using ::testing::Return; +using ::test_util::MockJNIUtil; +using ::test_util::JavaFileIndexInputMock; + +void setUpJavaFileInputMocking(JavaFileIndexInputMock &java_index_input, MockJNIUtil &mockJni) { + // Set up mocking values + mocking behavior in a method. + EXPECT_CALL(mockJni, CallNonvirtualIntMethodA(_, _, _, _, _)) + .WillRepeatedly([&java_index_input](JNIEnv *env, + jobject obj, + jclass clazz, + jmethodID methodID, + jvalue *args) { + return java_index_input.copyBytes(args[0].j); + }); + EXPECT_CALL(mockJni, CallNonvirtualLongMethodA(_, _, _, _, _)) + .WillRepeatedly([&java_index_input](JNIEnv *env, + jobject obj, + jclass clazz, + jmethodID methodID, + jvalue *args) { + return java_index_input.remainingBytes(); + }); + EXPECT_CALL(mockJni, GetPrimitiveArrayCritical(_, _, _)).WillRepeatedly([&java_index_input](JNIEnv *env, + jarray array, + jboolean *isCopy) { + return (jbyte *) java_index_input.buffer.data(); + }); + EXPECT_CALL(mockJni, ReleasePrimitiveArrayCritical(_, _, _, _)).WillRepeatedly(Return()); +} + +TEST(NmslibStreamLoadingTest, BasicAssertions) { + // Initialize nmslib + similarity::initLibrary(); + + // Define index data + int numIds = 100; + std::vector ids; + auto vectors = new std::vector(); + int dim = 2; + vectors->reserve(dim * numIds); + for (int i = 0; i < numIds; ++i) { + ids.push_back(i); + for (int j = 0; j < dim; ++j) { + vectors->push_back(test_util::RandomFloat(-500.0, 500.0)); + } + } + + std::string spaceType = knn_jni::L2; + std::string indexPath = test_util::RandomString( + 10, "/tmp/", ".nmslib"); + + std::unordered_map parametersMap; + int efConstruction = 512; + int m = 96; + + parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType; + parametersMap[knn_jni::EF_CONSTRUCTION] = (jobject) &efConstruction; + parametersMap[knn_jni::M] = (jobject) &m; + + // Set up jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + EXPECT_CALL(mockJNIUtil, + GetJavaObjectArrayLength( + jniEnv, reinterpret_cast(vectors))) + .WillRepeatedly(Return(vectors->size())); + + EXPECT_CALL(mockJNIUtil, + GetJavaIntArrayLength(jniEnv, reinterpret_cast(&ids))) + .WillRepeatedly(Return(ids.size())); + + EXPECT_CALL(mockJNIUtil, + ConvertJavaMapToCppMap(jniEnv, reinterpret_cast(¶metersMap))) + .WillRepeatedly(Return(parametersMap)); + + // Create the index + knn_jni::nmslib_wrapper::CreateIndex( + &mockJNIUtil, jniEnv, reinterpret_cast(&ids), + (jlong) vectors, dim, (jstring) &indexPath, + (jobject) ¶metersMap); + + // Create Java index input mock. + std::ifstream file_input{indexPath, std::ios::binary}; + const int32_t buffer_size = 128; + JavaFileIndexInputMock java_file_index_input_mock{file_input, buffer_size}; + setUpJavaFileInputMocking(java_file_index_input_mock, mockJNIUtil); + + // Make sure index can be loaded + jlong index = knn_jni::nmslib_wrapper::LoadIndexWithStream( + &mockJNIUtil, jniEnv, + (jobject) (&java_file_index_input_mock), + (jobject) (¶metersMap)); + + knn_jni::nmslib_wrapper::Free(index); + + // Clean up + std::remove(indexPath.c_str()); +} diff --git a/jni/tests/test_util.h b/jni/tests/test_util.h index 286000c08..a6b39aa41 100644 --- a/jni/tests/test_util.h +++ b/jni/tests/test_util.h @@ -111,7 +111,8 @@ namespace test_util { MOCK_METHOD(jclass, FindClassFromJNIEnv, (JNIEnv * env, const char *name)); MOCK_METHOD(jmethodID, GetMethodID, (JNIEnv * env, jclass clazz, const char *name, const char *sig)); MOCK_METHOD(jfieldID, GetFieldID, (JNIEnv * env, jclass clazz, const char *name, const char *sig)); - MOCK_METHOD(jint, CallIntMethodLong, (JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg)); + MOCK_METHOD(jint, CallNonvirtualIntMethodA, (JNIEnv * env, jobject obj, jclass clazz, jmethodID methodID, jvalue* args)); + MOCK_METHOD(jlong, CallNonvirtualLongMethodA, (JNIEnv * env, jobject obj, jclass clazz, jmethodID methodID, jvalue* args)); MOCK_METHOD(void *, GetPrimitiveArrayCritical, (JNIEnv * env, jarray array, jboolean *isCopy)); MOCK_METHOD(void, ReleasePrimitiveArrayCritical, (JNIEnv * env, jarray array, void *carray, jint mode)); }; diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java index 8adf35447..360c827f9 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java @@ -187,8 +187,8 @@ class IndexAllocation implements NativeMemoryAllocation { protected void closeInternal() { Runnable onClose = () -> { + writeLock(); try { - writeLock(); cleanup(); } finally { writeUnlock(); @@ -328,8 +328,8 @@ public TrainingDataAllocation(ExecutorService executor, long memoryAddress, int @Override public void close() { executor.execute(() -> { + writeLock(); try { - writeLock(); cleanup(); } finally { writeUnlock(); diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java index 51158d00c..5daa1e047 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java @@ -91,36 +91,11 @@ public void onFileDeleted(Path indexFilePath) { }; } - private NativeMemoryAllocation.IndexAllocation loadWithAbsoluteIndexPath( - NativeMemoryEntryContext.IndexEntryContext indexEntryContext - ) throws IOException { - Path indexPath = Paths.get(indexEntryContext.getKey()); - FileWatcher fileWatcher = new FileWatcher(indexPath); - fileWatcher.addListener(indexFileOnDeleteListener); - fileWatcher.init(); - - KNNEngine knnEngine = KNNEngine.getEngineNameFromPath(indexPath.toString()); - long indexAddress = JNIService.loadIndex(indexPath.toString(), indexEntryContext.getParameters(), knnEngine); - return createIndexAllocation( - indexEntryContext, - knnEngine, - indexAddress, - fileWatcher, - indexEntryContext.calculateSizeInKB(), - indexPath - ); - } - @Override public NativeMemoryAllocation.IndexAllocation load(NativeMemoryEntryContext.IndexEntryContext indexEntryContext) throws IOException { final Path absoluteIndexPath = Paths.get(indexEntryContext.getKey()); final KNNEngine knnEngine = KNNEngine.getEngineNameFromPath(absoluteIndexPath.toString()); - if (knnEngine != KNNEngine.FAISS) { - // We will support other non-FAISS native engines (ex: NMSLIB) soon. - return loadWithAbsoluteIndexPath(indexEntryContext); - } - final FileWatcher fileWatcher = new FileWatcher(absoluteIndexPath); fileWatcher.addListener(indexFileOnDeleteListener); fileWatcher.init(); @@ -182,7 +157,7 @@ class TrainingLoadStrategy NativeMemoryLoadStrategy, Closeable { - private static TrainingLoadStrategy INSTANCE; + private static volatile TrainingLoadStrategy INSTANCE; private final ExecutorService executor; private VectorReader vectorReader; diff --git a/src/main/java/org/opensearch/knn/index/store/IndexInputWithBuffer.java b/src/main/java/org/opensearch/knn/index/store/IndexInputWithBuffer.java index 273a4deac..0b1c934e5 100644 --- a/src/main/java/org/opensearch/knn/index/store/IndexInputWithBuffer.java +++ b/src/main/java/org/opensearch/knn/index/store/IndexInputWithBuffer.java @@ -18,11 +18,13 @@ */ public class IndexInputWithBuffer { private IndexInput indexInput; - // 4K buffer. - private byte[] buffer = new byte[4 * 1024]; + private long contentLength; + // 64K buffer. + private byte[] buffer = new byte[64 * 1024]; public IndexInputWithBuffer(@NonNull IndexInput indexInput) { this.indexInput = indexInput; + this.contentLength = indexInput.length(); } /** @@ -39,6 +41,10 @@ private int copyBytes(long nbytes) throws IOException { return readBytes; } + private long remainingBytes() { + return contentLength - indexInput.getFilePointer(); + } + @Override public String toString() { return "{indexInput=" + indexInput + ", len(buffer)=" + buffer.length + "}"; diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 448241f9c..a0daf65a7 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -227,6 +227,8 @@ public static long loadIndex(IndexInputWithBuffer readStream, Map parameters); + /** + * Load an index into memory through the provided read stream wrapping Lucene's IndexInput. + * + * @param readStream Read stream wrapping Lucene's IndexInput. + * @param parameters Parameters to be used when loading index + * @return Pointer to location in memory the index resides in + */ + public static native long loadIndexWithStream(IndexInputWithBuffer readStream, Map parameters); + /** * Query an index * diff --git a/src/test/java/org/opensearch/knn/index/KNNCircuitBreakerIT.java b/src/test/java/org/opensearch/knn/index/KNNCircuitBreakerIT.java index f7f68eda1..935a2e22c 100644 --- a/src/test/java/org/opensearch/knn/index/KNNCircuitBreakerIT.java +++ b/src/test/java/org/opensearch/knn/index/KNNCircuitBreakerIT.java @@ -48,7 +48,7 @@ private void tripCb() throws Exception { createKnnIndex(indexName2, settings, createKnnIndexMapping(FIELD_NAME, 2)); Float[] vector = { 1.3f, 2.2f }; - int docsInIndex = 5; // through testing, 7 is minimum number of docs to trip circuit breaker at 1kb + int docsInIndex = 7; // through testing, 7 is minimum number of docs to trip circuit breaker at 1kb for (int i = 0; i < docsInIndex; i++) { addKnnDoc(indexName1, Integer.toString(i), FIELD_NAME, vector); diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index 8566b0223..f6d118092 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -749,6 +749,31 @@ public void testLoadIndex_nmslib_valid() throws IOException { assertNotEquals(0, pointer); } + public void testLoadIndex_nmslib_valid_with_stream() throws IOException { + Path tmpFile = createTempFile(); + + TestUtils.createIndex( + testData.indexData.docs, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB + ); + assertTrue(tmpFile.toFile().length() > 0); + + try (final Directory directory = new MMapDirectory(tmpFile.getParent())) { + try (IndexInput indexInput = directory.openInput(tmpFile.getFileName().toString(), IOContext.READONCE)) { + long pointer = JNIService.loadIndex( + new IndexInputWithBuffer(indexInput), + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB + ); + assertNotEquals(0, pointer); + } + } + } + public void testLoadIndex_faiss_invalid_fileDoesNotExist() { expectThrows(Exception.class, () -> JNIService.loadIndex("invalid", Collections.emptyMap(), KNNEngine.FAISS)); }