diff --git a/jni/cmake/init-faiss.cmake-e b/jni/cmake/init-faiss.cmake-e new file mode 100644 index 000000000..bef93eda0 --- /dev/null +++ b/jni/cmake/init-faiss.cmake-e @@ -0,0 +1,69 @@ +# +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +# + +# Check if faiss exists +find_path(FAISS_REPO_DIR NAMES faiss PATHS ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss NO_DEFAULT_PATH) + +# If not, pull the updated submodule +if (NOT EXISTS ${FAISS_REPO_DIR}) + message(STATUS "Could not find faiss. Pulling updated submodule.") + execute_process(COMMAND git submodule update --init -- external/faiss WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) +endif () + +# Check if patch exist, this is to skip git apply during CI build. See CI.yml with ubuntu. +find_path(PATCH_FILE NAMES 0001-Custom-patch-to-support-multi-vector.patch 0002-Enable-precomp-table-to-be-shared-ivfpq.patch 0003-Custom-patch-to-support-range-search-params.patch PATHS ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss NO_DEFAULT_PATH) + +# If it exists, apply patches +if (EXISTS ${PATCH_FILE}) + message(STATUS "Applying custom patches.") + execute_process(COMMAND git ${GIT_PATCH_COMMAND} --3way --ignore-space-change --ignore-whitespace ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss ERROR_VARIABLE ERROR_MSG RESULT_VARIABLE RESULT_CODE) + execute_process(COMMAND git ${GIT_PATCH_COMMAND} --3way --ignore-space-change --ignore-whitespace ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0002-Enable-precomp-table-to-be-shared-ivfpq.patch WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss ERROR_VARIABLE ERROR_MSG RESULT_VARIABLE RESULT_CODE) + execute_process(COMMAND git ${GIT_PATCH_COMMAND} --3way --ignore-space-change --ignore-whitespace ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0003-Custom-patch-to-support-range-search-params.patch WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss ERROR_VARIABLE ERROR_MSG RESULT_VARIABLE RESULT_CODE) + if(RESULT_CODE) + message(FATAL_ERROR "Failed to apply patch:\n${ERROR_MSG}") + endif() +endif() + +if (${CMAKE_SYSTEM_NAME} STREQUAL Darwin) + if(CMAKE_C_COMPILER_ID MATCHES "Clang\$") + set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp") + set(OpenMP_C_LIB_NAMES "omp") + set(OpenMP_omp_LIBRARY /usr/local/opt/libomp/lib/libomp.dylib) + endif() + + if(CMAKE_CXX_COMPILER_ID MATCHES "Clang\$") + set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I/usr/local/opt/libomp/include") + set(OpenMP_CXX_LIB_NAMES "omp") + set(OpenMP_omp_LIBRARY /usr/local/opt/libomp/lib/libomp.dylib) + endif() +endif() + +find_package(ZLIB REQUIRED) + +# Statically link BLAS - ensure this is before we find the blas package so we dont dynamically link +set(BLA_STATIC ON) +find_package(BLAS REQUIRED) +enable_language(Fortran) +find_package(LAPACK REQUIRED) + +# Set relevant properties +set(BUILD_TESTING OFF) # Avoid building faiss tests +set(FAISS_ENABLE_GPU OFF) +set(FAISS_ENABLE_PYTHON OFF) + +if(NOT DEFINED SIMD_ENABLED) + set(SIMD_ENABLED true) # set default value as true if the argument is not set +endif() + +if(${CMAKE_SYSTEM_NAME} STREQUAL Windows OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64" OR NOT ${SIMD_ENABLED}) + set(FAISS_OPT_LEVEL generic) # Keep optimization level as generic on Windows OS as it is not supported due to MINGW64 compiler issue. Also, on aarch64 avx2 is not supported. + set(TARGET_LINK_FAISS_LIB faiss) +else() + set(FAISS_OPT_LEVEL avx2) # Keep optimization level as avx2 to improve performance on Linux and Mac. + set(TARGET_LINK_FAISS_LIB faiss_avx2) + string(PREPEND LIB_EXT "_avx2") # Prepend "_avx2" to lib extension to create the library as "libopensearchknn_faiss_avx2.so" on linux and "libopensearchknn_faiss_avx2.jnilib" on mac +endif() + +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/external/faiss EXCLUDE_FROM_ALL) diff --git a/jni/include/faiss_index_service.h b/jni/include/faiss_index_service.h index 59f15fda9..9de374950 100644 --- a/jni/include/faiss_index_service.h +++ b/jni/include/faiss_index_service.h @@ -31,8 +31,50 @@ namespace faiss_wrapper { class IndexService { public: IndexService(std::unique_ptr faissMethods); - //TODO Remove dependency on JNIUtilInterface and JNIEnv - //TODO Reduce the number of parameters + /** + * Initialize index + * + * @param jniUtil jni util + * @param env jni environment + * @param metric space type for distance calculation + * @param indexDescription index description to be used by faiss index factory + * @param dim dimension of vectors + * @param numVectors number of vectors + * @param threadCount number of thread count to be used while adding data + * @param parameters parameters to be applied to faiss index + * @return pointer to the native index object + */ + virtual jlong initIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numVectors, int threadCount, std::unordered_map parameters); + /** + * Add vectors to index + * + * @param jniUtil jni util + * @param env jni environment + * @param metric space type for distance calculation + * @param indexDescription index description to be used by faiss index factory + * @param dim dimension of vectors + * @param numIds number of vectors + * @param threadCount number of thread count to be used while adding data + * @param vectorsAddress memory address which is holding vector data + * @param idMap a map of document id and vector id + * @param parameters parameters to be applied to faiss index + */ + virtual void insertToIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numIds, int threadCount, int64_t vectorsAddress, std::vector &ids, jlong idMapAddress, std::unordered_map parameters); + /** + * Write index to disk + * + * @param jniUtil jni util + * @param env jni environment + * @param metric space type for distance calculation + * @param indexDescription index description to be used by faiss index factory + * @param threadCount number of thread count to be used while adding data + * @param indexPath path to write index + * @param idMap a map of document id and vector id + * @param parameters parameters to be applied to faiss index + */ + virtual void writeIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int threadCount, std::string indexPath, jlong idMapAddress, std::unordered_map parameters); + // TODO Remove dependency on JNIUtilInterface and JNIEnv + // TODO Reduce the number of parameters /** * Create index @@ -75,6 +117,48 @@ class BinaryIndexService : public IndexService { //TODO Remove dependency on JNIUtilInterface and JNIEnv //TODO Reduce the number of parameters BinaryIndexService(std::unique_ptr faissMethods); + /** + * Initialize index + * + * @param jniUtil jni util + * @param env jni environment + * @param metric space type for distance calculation + * @param indexDescription index description to be used by faiss index factory + * @param dim dimension of vectors + * @param numVectors number of vectors + * @param threadCount number of thread count to be used while adding data + * @param parameters parameters to be applied to faiss index + * @return pointer to the native index object + */ + virtual jlong initIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numVectors, int threadCount, std::unordered_map parameters) override; + /** + * Add vectors to index + * + * @param jniUtil jni util + * @param env jni environment + * @param metric space type for distance calculation + * @param indexDescription index description to be used by faiss index factory + * @param dim dimension of vectors + * @param numIds number of vectors + * @param threadCount number of thread count to be used while adding data + * @param vectorsAddress memory address which is holding vector data + * @param idMap a map of document id and vector id + * @param parameters parameters to be applied to faiss index + */ + virtual void insertToIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numIds, int threadCount, int64_t vectorsAddress, std::vector &ids, jlong idMapAddress, std::unordered_map parameters) override; + /** + * Write index to disk + * + * @param jniUtil jni util + * @param env jni environment + * @param metric space type for distance calculation + * @param indexDescription index description to be used by faiss index factory + * @param threadCount number of thread count to be used while adding data + * @param indexPath path to write index + * @param idMap a map of document id and vector id + * @param parameters parameters to be applied to faiss index + */ + virtual void writeIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int threadCount, std::string indexPath, jlong idMapAddress, std::unordered_map parameters) override; /** * Create binary index * diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 5ad0dedc4..7ac0fd945 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -18,6 +18,12 @@ namespace knn_jni { namespace faiss_wrapper { + jlong InitIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong numDocs, jint dimJ, jobject parametersJ, IndexService *indexService); + + void InsertToIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jlong index_ptr, jobject parametersJ, IndexService *indexService); + + void WriteIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jstring indexPathJ, jlong index_ptr, jobject parametersJ, IndexService *indexService); + // Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ. // The index is serialized to indexPathJ. void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 025fb12e8..c2845dad3 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -18,6 +18,54 @@ #ifdef __cplusplus extern "C" { #endif +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: initIndex + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndex(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: initBinaryIndex + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndex(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: insertToIndex + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JNIEnv * env, jclass cls, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, + jlong indexAddress, jobject parametersJ); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: insertToBinaryIndex + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, + jlong indexAddress, jobject parametersJ); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: writeIndex + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEnv * env, jclass cls, + jlong indexAddress, + jstring indexPathJ, jobject parametersJ); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: writeBinaryIndex + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex(JNIEnv * env, jclass cls, + jlong indexAddress, + jstring indexPathJ, jobject parametersJ); /* * Class: org_opensearch_knn_jni_FaissService diff --git a/jni/src/.idea/modules.xml b/jni/src/.idea/modules.xml new file mode 100644 index 000000000..f669a0e59 --- /dev/null +++ b/jni/src/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/jni/src/.idea/vcs.xml b/jni/src/.idea/vcs.xml new file mode 100644 index 000000000..b2bdec2d7 --- /dev/null +++ b/jni/src/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/jni/src/.idea/workspace.xml b/jni/src/.idea/workspace.xml new file mode 100644 index 000000000..45289f4a4 --- /dev/null +++ b/jni/src/.idea/workspace.xml @@ -0,0 +1,55 @@ + + + + + + + + + + + + + + + + + + + + + + + + 1721418022030 + + + + \ No newline at end of file diff --git a/jni/src/faiss_index_service.cpp b/jni/src/faiss_index_service.cpp index 8c5ba36af..22e0d5b87 100644 --- a/jni/src/faiss_index_service.cpp +++ b/jni/src/faiss_index_service.cpp @@ -57,6 +57,110 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, IndexService::IndexService(std::unique_ptr faissMethods) : faissMethods(std::move(faissMethods)) {} +jlong IndexService::initIndex( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + faiss::MetricType metric, + std::string indexDescription, + int dim, + int numVectors, + int threadCount, + std::unordered_map parameters + ) { + // Removed the check for number of vectors + // Don't use unique_ptr here since we want to access the index in the future. + faiss::Index * indexWriter(faissMethods->indexFactory(dim, indexDescription.c_str(), metric)); + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(threadCount != 0) { + omp_set_num_threads(threadCount); + } + + // Add extra parameters that cant be configured with the index factory + SetExtraParameters(jniUtil, env, parameters, indexWriter); + + // Check that the index does not need to be trained + if(!indexWriter->is_trained) { + throw std::runtime_error("Index is not trained"); + } + + // Add vectors + faiss::IndexIDMap * idMap(faissMethods->indexIdMap(indexWriter)); + + faiss::IndexHNSW * hnsw = dynamic_cast(idMap->index); + + if(hnsw != NULL) { + faiss::IndexFlat * storage = dynamic_cast(hnsw->storage); + if(storage != NULL) { + storage->codes.reserve(dim * numVectors); + } + } + + return (jlong)idMap; +} + +void IndexService::insertToIndex( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + faiss::MetricType metric, + std::string indexDescription, + int dim, + int numIds, + int threadCount, + int64_t vectorsAddress, + std::vector & ids, + jlong idMapAddress, + std::unordered_map parameters + ) { + // Read vectors from memory address (unique ptr since we want to remove from memory after use) + std::vector * inputVectors = reinterpret_cast*>(vectorsAddress); + + // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value + int numVectors = (int) (inputVectors->size() / (uint64_t) dim); + if(numVectors == 0) { + throw std::runtime_error("Number of vectors cannot be 0"); + } + + if (numIds != numVectors) { + throw std::runtime_error("Number of IDs does not match number of vectors"); + } + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(threadCount != 0) { + omp_set_num_threads(threadCount); + } + + faiss::IndexIDMap * idMap = reinterpret_cast (idMapAddress); + + // Add vectors + idMap->add_with_ids(numVectors, inputVectors->data(), ids.data()); +} + +void IndexService::writeIndex( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + faiss::MetricType metric, + std::string indexDescription, + int threadCount, + std::string indexPath, + jlong idMapAddress, + std::unordered_map parameters + ) { + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(threadCount != 0) { + omp_set_num_threads(threadCount); + } + + faiss::IndexIDMap * idMap = reinterpret_cast (idMapAddress); + + // Write the index to disk + faissMethods->writeIndex(idMap, indexPath.c_str()); + + // Free the memory used by the index + delete idMap; + delete idMap->index; +} + void IndexService::createIndex( knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, @@ -108,6 +212,110 @@ void IndexService::createIndex( BinaryIndexService::BinaryIndexService(std::unique_ptr faissMethods) : IndexService(std::move(faissMethods)) {} +jlong BinaryIndexService::initIndex( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + faiss::MetricType metric, + std::string indexDescription, + int dim, + int numVectors, + int threadCount, + std::unordered_map parameters + ) { + // Removed the check for number of vectors + // Don't use unique_ptr here since we want to access the index in the future. + faiss::IndexBinary * indexWriter(faissMethods->indexBinaryFactory(dim, indexDescription.c_str())); + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(threadCount != 0) { + omp_set_num_threads(threadCount); + } + + // Add extra parameters that cant be configured with the index factory + SetExtraParameters(jniUtil, env, parameters, indexWriter); + + // Check that the index does not need to be trained + if(!indexWriter->is_trained) { + throw std::runtime_error("Index is not trained"); + } + + // Add vectors + faiss::IndexBinaryIDMap * idMap(faissMethods->indexBinaryIdMap(indexWriter)); + + faiss::IndexBinaryHNSW * hnsw = dynamic_cast(idMap->index); + + if(hnsw != NULL) { + faiss::IndexBinaryFlat * storage = dynamic_cast(hnsw->storage); + if(storage != NULL) { + storage->xb.reserve(dim / 8 * numVectors); + } + } + + return (jlong)idMap; +} + +void BinaryIndexService::insertToIndex( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + faiss::MetricType metric, + std::string indexDescription, + int dim, + int numIds, + int threadCount, + int64_t vectorsAddress, + std::vector & ids, + jlong idMapAddress, + std::unordered_map parameters + ) { + // Read vectors from memory address (unique ptr since we want to remove from memory after use) + std::vector * inputVectors = reinterpret_cast*>(vectorsAddress); + + // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value + int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8)); + if(numVectors == 0) { + throw std::runtime_error("Number of vectors cannot be 0"); + } + + if (numIds != numVectors) { + throw std::runtime_error("Number of IDs does not match number of vectors"); + } + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(threadCount != 0) { + omp_set_num_threads(threadCount); + } + + faiss::IndexBinaryIDMap * idMap = reinterpret_cast (idMapAddress); + + // Add vectors + idMap->add_with_ids(numVectors, inputVectors->data(), ids.data()); +} + +void BinaryIndexService::writeIndex( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + faiss::MetricType metric, + std::string indexDescription, + int threadCount, + std::string indexPath, + jlong idMapAddress, + std::unordered_map parameters + ) { + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(threadCount != 0) { + omp_set_num_threads(threadCount); + } + + faiss::IndexBinaryIDMap * idMap = reinterpret_cast (idMapAddress); + + // Write the index to disk + faissMethods->writeIndexBinary(idMap, indexPath.c_str()); + + // Free the memory used by the index + delete idMap; + delete idMap->index; +} + void BinaryIndexService::createIndex( knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 1d4437414..2644a6bfe 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -88,6 +88,169 @@ bool isIndexIVFPQL2(faiss::Index * index); // IndexIDMap which has member that will point to underlying index that stores the data faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index); +jlong knn_jni::faiss_wrapper::InitIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong numDocs, jint dimJ, + jobject parametersJ, IndexService* indexService) { + + if(dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); + } + + if (parametersJ == nullptr) { + throw std::runtime_error("Parameters cannot be null"); + } + + // parametersJ is a Java Map. ConvertJavaMapToCppMap converts it to a c++ map + // so that it is easier to access. + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + + // Parameters to pass + // Metric type + jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); + std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); + faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp); + jniUtil->DeleteLocalRef(env, spaceTypeJ); + + // Dimension + int dim = (int)dimJ; + + // Number of docs + int docs = (int)numDocs; + + // Index description + jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); + std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); + jniUtil->DeleteLocalRef(env, indexDescriptionJ); + + // Thread count + int threadCount = 0; + if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + } + + // Extra parameters + // TODO: parse the entire map and remove jni object + std::unordered_map subParametersCpp; + if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { + subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersCpp[knn_jni::PARAMETERS]); + } + // end parameters to pass + + // Create index + return indexService->initIndex(jniUtil, env, metric, indexDescriptionCpp, dim, numDocs, threadCount, subParametersCpp); +} + +void knn_jni::faiss_wrapper::InsertToIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, + jlong index_ptr, jobject parametersJ, IndexService* indexService) { + if (idsJ == nullptr) { + throw std::runtime_error("IDs cannot be null"); + } + + if (vectorsAddressJ <= 0) { + throw std::runtime_error("VectorsAddress cannot be less than 0"); + } + + if(dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); + } + + if (parametersJ == nullptr) { + throw std::runtime_error("Parameters cannot be null"); + } + + // parametersJ is a Java Map. ConvertJavaMapToCppMap converts it to a c++ map + // so that it is easier to access. + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + + // Parameters to pass + // Metric type + jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); + std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); + faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp); + jniUtil->DeleteLocalRef(env, spaceTypeJ); + + // Dimension + int dim = (int)dimJ; + + // Number of vectors + int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); + + // Index description + jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); + std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); + jniUtil->DeleteLocalRef(env, indexDescriptionJ); + + // Thread count + int threadCount = 0; + if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + } + + // Vectors address + int64_t vectorsAddress = (int64_t)vectorsAddressJ; + + // Ids + auto ids = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); + + // Extra parameters + // TODO: parse the entire map and remove jni object + std::unordered_map subParametersCpp; + if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { + subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersCpp[knn_jni::PARAMETERS]); + } + // end parameters to pass + + // Create index + indexService->insertToIndex(jniUtil, env, metric, indexDescriptionCpp, dim, numIds, threadCount, vectorsAddress, ids, index_ptr, subParametersCpp); +} + +void knn_jni::faiss_wrapper::WriteIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, + jstring indexPathJ, jlong index_ptr, jobject parametersJ, IndexService* indexService) { + + if (indexPathJ == nullptr) { + throw std::runtime_error("Index path cannot be null"); + } + + if (parametersJ == nullptr) { + throw std::runtime_error("Parameters cannot be null"); + } + + // parametersJ is a Java Map. ConvertJavaMapToCppMap converts it to a c++ map + // so that it is easier to access. + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + + // Parameters to pass + // Metric type + jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); + std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); + faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp); + jniUtil->DeleteLocalRef(env, spaceTypeJ); + + // Index description + jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); + std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); + jniUtil->DeleteLocalRef(env, indexDescriptionJ); + + // Thread count + int threadCount = 0; + if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + } + + // Index path + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + + // Extra parameters + // TODO: parse the entire map and remove jni object + std::unordered_map subParametersCpp; + if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { + subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersCpp[knn_jni::PARAMETERS]); + } + // end parameters to pass + + // Create index + indexService->writeIndex(jniUtil, env, metric, indexDescriptionCpp, threadCount, indexPathCpp, index_ptr, subParametersCpp); +} + void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jobject parametersJ, IndexService* indexService) { if (idsJ == nullptr) { diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 2394e2951..de4261688 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -39,6 +39,96 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) { jniUtil.Uninitialize(env); } +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndex(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); + return knn_jni::faiss_wrapper::InitIndex(&jniUtil, env, numDocs, dimJ, parametersJ, &indexService); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return (jlong)0; +} + +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndex(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods)); + return knn_jni::faiss_wrapper::InitIndex(&jniUtil, env, numDocs, dimJ, parametersJ, &binaryIndexService); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return (jlong)0; +} + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JNIEnv * env, jclass cls, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, + jlong indexAddress, jobject parametersJ) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::InsertToIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexAddress, parametersJ, &indexService); + + // Releasing the vectorsAddressJ memory as that is not required once we have created the index. + // This is not the ideal approach, please refer this gh issue for long term solution: + // https://github.com/opensearch-project/k-NN/issues/1600 + delete reinterpret_cast*>(vectorsAddressJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, + jlong indexAddress, jobject parametersJ) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::InsertToIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexAddress, parametersJ, &binaryIndexService); + + // Releasing the vectorsAddressJ memory as that is not required once we have created the index. + // This is not the ideal approach, please refer this gh issue for long term solution: + // https://github.com/opensearch-project/k-NN/issues/1600 + delete reinterpret_cast*>(vectorsAddressJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEnv * env, jclass cls, + jlong indexAddress, + jstring indexPathJ, jobject parametersJ) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::WriteIndex(&jniUtil, env, indexPathJ, indexAddress, parametersJ, &indexService); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex(JNIEnv * env, jclass cls, + jlong indexAddress, + jstring indexPathJ, jobject parametersJ) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::WriteIndex(&jniUtil, env, indexPathJ, indexAddress, parametersJ, &binaryIndexService); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex(JNIEnv * env, jclass cls, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jobject parametersJ) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index ea5cb5e3b..40d9d65ca 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -115,6 +115,14 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, throws IOException { // Get values to be indexed BinaryDocValues values = valuesProducer.getBinary(field); + if (KNNCodecUtil.getTotalLiveDocsCount(values) == 0) { + return; + } + // Increment counter for number of graph index requests + KNNCounter.GRAPH_INDEX_REQUESTS.increment(); + if (isMerge) { + KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); + } final KNNEngine knnEngine = getKNNEngine(field); final String engineFileName = buildEngineFileName( state.segmentInfo.name, @@ -127,142 +135,258 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, engineFileName ).toString(); - // Determine if we are creating an index from a model or from scratch - NativeIndexCreator indexCreator; - KNNCodecUtil.Pair pair; - Map fieldAttributes = field.attributes(); - - if (fieldAttributes.containsKey(MODEL_ID)) { - String modelId = fieldAttributes.get(MODEL_ID); - Model model = ModelCache.getInstance().get(modelId); - if (model.getModelBlob() == null) { - throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); - } - VectorDataType vectorDataType = model.getModelMetadata().getVectorDataType(); - pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType)); - indexCreator = () -> createKNNIndexFromTemplate(model, pair, knnEngine, indexPath); - } else { - // get vector data type from field attributes or provide default value - VectorDataType vectorDataType = VectorDataType.get( - fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()) - ); - pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType)); - indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath); - } - - // Skip index creation if no vectors or docs in segment - if (pair.getVectorAddress() == 0 || pair.docs.length == 0) { - logger.info("Skipping engine index creation as there are no vectors or docs in the segment"); - return; - } - - long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode); - - if (isMerge) { - KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); - KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(pair.docs.length); - KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize); - recordMergeStats(pair.docs.length, arraySize); - } - - // Increment counter for number of graph index requests - KNNCounter.GRAPH_INDEX_REQUESTS.increment(); + state.directory.createOutput(engineFileName, state.context).close(); + boolean fromScratch = !field.attributes().containsKey(MODEL_ID); + boolean iterative = fromScratch && KNNEngine.FAISS == knnEngine; + createKNNIndex(field, values, knnEngine, indexPath, fromScratch, iterative, isMerge); if (isRefresh) { recordRefreshStats(); } - - // Ensure engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper - state.directory.createOutput(engineFileName, state.context).close(); - indexCreator.createIndex(); writeFooter(indexPath, engineFileName); } + private void currentMergeStats(int length, long arraySize) { + KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); + KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(length); + KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize); + KNNGraphValue.MERGE_TOTAL_OPERATIONS.increment(); + KNNGraphValue.MERGE_TOTAL_DOCS.incrementBy(length); + KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.incrementBy(arraySize); + } + private void recordMergeStats(int length, long arraySize) { KNNGraphValue.MERGE_CURRENT_OPERATIONS.decrement(); KNNGraphValue.MERGE_CURRENT_DOCS.decrementBy(length); KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.decrementBy(arraySize); - KNNGraphValue.MERGE_TOTAL_OPERATIONS.increment(); - KNNGraphValue.MERGE_TOTAL_DOCS.incrementBy(length); - KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.incrementBy(arraySize); } private void recordRefreshStats() { KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment(); } - private void createKNNIndexFromTemplate(Model model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) { + private Map genParameters(boolean fromScratch, FieldInfo fieldInfo, KNNEngine knnEngine) throws IOException { Map parameters = new HashMap<>(); + ; + if (fromScratch) { + Map fieldAttributes = fieldInfo.attributes(); + String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); + + // parametersString will be null when legacy mapper is used + if (parametersString == null) { + parameters.put( + KNNConstants.SPACE_TYPE, + fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()) + ); + + String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION); + Map algoParams = new HashMap<>(); + if (efConstruction != null) { + algoParams.put(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, Integer.parseInt(efConstruction)); + } + + String m = fieldAttributes.get(KNNConstants.HNSW_ALGO_M); + if (m != null) { + algoParams.put(KNNConstants.METHOD_PARAMETER_M, Integer.parseInt(m)); + } + parameters.put(PARAMETERS, algoParams); + } else { + parameters.putAll( + XContentHelper.createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(parametersString), + MediaTypeRegistry.getDefaultMediaType() + ).map() + ); + } + + // Update index description of Faiss for binary data type + if (KNNEngine.FAISS == knnEngine + && VectorDataType.BINARY.getValue() + .equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue())) + && parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null) { + parameters.put( + KNNConstants.INDEX_DESCRIPTION_PARAMETER, + FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() + ); + IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY); + } + } + // Used to determine how many threads to use when indexing parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + return parameters; + } + + private long initIndexFromScratch(long size, int dim, KNNEngine knnEngine, Map parameters) throws IOException { + // Pass the path for the nms library to save the file + return AccessController.doPrivileged((PrivilegedAction) () -> { + return JNIService.initIndexFromScratch(size, dim, parameters, knnEngine); + }); + } + + private void insertToIndex(KNNCodecUtil.VectorBatch pair, KNNEngine knnEngine, long indexAddress, Map parameters) + throws IOException { + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.insertToIndex(pair.docs, pair.getVectorAddress(), pair.getDimension(), parameters, indexAddress, knnEngine); + return null; + }); + } + private void writeIndex(long indexAddress, String indexPath, KNNEngine knnEngine, Map parameters) throws IOException { + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.writeIndex(indexPath, indexAddress, knnEngine, parameters); + return null; + }); + } + + private void createKNNIndexFromTemplate( + FieldInfo field, + BinaryDocValues values, + KNNEngine knnEngine, + String indexPath, + Map parameters, + boolean isMerge + ) throws IOException { + String modelId = field.attributes().get(MODEL_ID); + Model model = ModelCache.getInstance().get(modelId); + if (model.getModelBlob() == null) { + throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); + } + byte[] modelBlob = model.getModelBlob(); IndexUtil.updateVectorDataTypeToParameters(parameters, model.getModelMetadata().getVectorDataType()); + VectorDataType vectorDataType = model.getModelMetadata().getVectorDataType(); + KNNCodecUtil.VectorBatch batch = KNNCodecUtil.getVectorBatch(values, getVectorTransfer(vectorDataType), false); + + int numDocs = (int) KNNCodecUtil.getTotalLiveDocsCount(values); + + if (numDocs == 0) { + return; + } + + long arraySize = calculateArraySize(numDocs, batch.getDimension(), batch.serializationMode); + + if (isMerge) { + currentMergeStats(numDocs, arraySize); + } AccessController.doPrivileged((PrivilegedAction) () -> { JNIService.createIndexFromTemplate( - pair.docs, - pair.getVectorAddress(), - pair.getDimension(), + batch.docs, + batch.getVectorAddress(), + batch.getDimension(), indexPath, - model.getModelBlob(), + modelBlob, parameters, knnEngine ); return null; }); + + if (isMerge) { + recordMergeStats(numDocs, arraySize); + } } - private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) - throws IOException { - Map parameters = new HashMap<>(); + private void createKNNIndexFromScratch( + FieldInfo fieldInfo, + BinaryDocValues values, + KNNEngine knnEngine, + String indexPath, + Map parameters, + boolean isMerge + ) throws IOException { Map fieldAttributes = fieldInfo.attributes(); - String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); + VectorDataType vectorDataType = VectorDataType.get( + fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()) + ); + VectorTransfer transfer = getVectorTransfer(vectorDataType); + KNNCodecUtil.VectorBatch batch = KNNCodecUtil.getVectorBatch(values, transfer, false); - // parametersString will be null when legacy mapper is used - if (parametersString == null) { - parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue())); + int numDocs = (int) KNNCodecUtil.getTotalLiveDocsCount(values); - String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION); - Map algoParams = new HashMap<>(); - if (efConstruction != null) { - algoParams.put(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, Integer.parseInt(efConstruction)); - } - - String m = fieldAttributes.get(KNNConstants.HNSW_ALGO_M); - if (m != null) { - algoParams.put(KNNConstants.METHOD_PARAMETER_M, Integer.parseInt(m)); - } - parameters.put(PARAMETERS, algoParams); - } else { - parameters.putAll( - XContentHelper.createParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, - new BytesArray(parametersString), - MediaTypeRegistry.getDefaultMediaType() - ).map() - ); + if (numDocs == 0) { + return; } - // Update index description of Faiss for binary data type - if (KNNEngine.FAISS == knnEngine - && VectorDataType.BINARY.getValue() - .equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue())) - && parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null) { - parameters.put( - KNNConstants.INDEX_DESCRIPTION_PARAMETER, - FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() - ); - IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY); - } + long arraySize = calculateArraySize(numDocs, batch.getDimension(), batch.serializationMode); - // Used to determine how many threads to use when indexing - parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + if (isMerge) { + currentMergeStats(numDocs, arraySize); + } - // Pass the path for the nms library to save the file AccessController.doPrivileged((PrivilegedAction) () -> { - JNIService.createIndex(pair.docs, pair.getVectorAddress(), pair.getDimension(), indexPath, parameters, knnEngine); + JNIService.createIndex(batch.docs, batch.getVectorAddress(), batch.getDimension(), indexPath, parameters, knnEngine); return null; }); + + if (isMerge) { + recordMergeStats(numDocs, arraySize); + } + } + + private void createKNNIndexFromScratchIteratively( + FieldInfo fieldInfo, + BinaryDocValues values, + KNNEngine knnEngine, + String indexPath, + Map parameters, + boolean isMerge + ) throws IOException { + Map fieldAttributes = fieldInfo.attributes(); + VectorDataType vectorDataType = VectorDataType.get( + fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()) + ); + VectorTransfer transfer = getVectorTransfer(vectorDataType); + KNNCodecUtil.VectorBatch batch = KNNCodecUtil.getVectorBatch(values, transfer, true); + + int numDocs = (int) KNNCodecUtil.getTotalLiveDocsCount(values); + + if (numDocs == 0) { + return; + } + + long arraySize = calculateArraySize(numDocs, batch.getDimension(), batch.serializationMode); + + if (isMerge) { + currentMergeStats(numDocs, arraySize); + } + + long indexAddress = initIndexFromScratch(numDocs, batch.getDimension(), knnEngine, parameters); + for (; !batch.finished; batch = KNNCodecUtil.getVectorBatch(values, transfer, true)) { + insertToIndex(batch, knnEngine, indexAddress, parameters); + } + insertToIndex(batch, knnEngine, indexAddress, parameters); + writeIndex(indexAddress, indexPath, knnEngine, parameters); + if (isMerge) { + recordMergeStats(numDocs, arraySize); + } + } + + private void createKNNIndex( + FieldInfo fieldInfo, + BinaryDocValues values, + KNNEngine knnEngine, + String indexPath, + boolean fromScratch, + boolean iterative, + boolean isMerge + ) throws IOException { + Map parameters = genParameters(fromScratch, fieldInfo, knnEngine); + if (fromScratch && iterative) { + createKNNIndexFromScratchIteratively(fieldInfo, values, knnEngine, indexPath, parameters, isMerge); + } else if (fromScratch) { + createKNNIndexFromScratch(fieldInfo, values, knnEngine, indexPath, parameters, isMerge); + } else { + createKNNIndexFromTemplate(fieldInfo, values, knnEngine, indexPath, parameters, isMerge); + } + /* + if(fromScratch) { + createKNNIndexFromScratch(fieldInfo, values, knnEngine, indexPath, parameters, isMerge); + } else { + createKNNIndexFromTemplate(fieldInfo, values, knnEngine, indexPath, parameters, isMerge); + } + */ } /** diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java index 5e9831708..f193aa05b 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java @@ -34,11 +34,14 @@ public void init(final long totalLiveDocs) { public void transfer(final BytesRef bytesRef) { dimension = bytesRef.length * 8; if (vectorsPerTransfer == Integer.MIN_VALUE) { - vectorsPerTransfer = (bytesRef.length * totalLiveDocs) / vectorsStreamingMemoryLimit; + vectorsPerTransfer = vectorsStreamingMemoryLimit / bytesRef.length; + if (totalLiveDocs > 0) { + vectorsPerTransfer = Math.min(vectorsPerTransfer, totalLiveDocs); + } // This condition comes if vectorsStreamingMemoryLimit is higher than total number floats to transfer // Doing this will reduce 1 extra trip to JNI layer. if (vectorsPerTransfer == 0) { - vectorsPerTransfer = totalLiveDocs; + vectorsPerTransfer = 1; } } @@ -60,7 +63,15 @@ public SerializationMode getSerializationMode(final BytesRef bytesRef) { private void transfer() { int lengthOfVector = dimension / 8; - vectorAddress = JNICommons.storeByteVectorData(vectorAddress, vectorList.toArray(new byte[][] {}), totalLiveDocs * lengthOfVector); + if (totalLiveDocs != 0) { + vectorAddress = JNICommons.storeByteVectorData( + vectorAddress, + vectorList.toArray(new byte[][] {}), + totalLiveDocs * lengthOfVector + ); + } else { + vectorAddress = JNICommons.storeByteVectorData(0, vectorList.toArray(new byte[][] {}), vectorList.size() * lengthOfVector); + } vectorList.clear(); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java index af6d9490e..7c5e93609 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java @@ -35,14 +35,19 @@ public void init(final long totalLiveDocs) { public void transfer(final BytesRef bytesRef) { final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByBytesRef(bytesRef); final float[] vector = vectorSerializer.byteToFloatArray(bytesRef); + // System.out.println("Vector: " + vector.length); dimension = vector.length; if (vectorsPerTransfer == Integer.MIN_VALUE) { - vectorsPerTransfer = (dimension * Float.BYTES * totalLiveDocs) / vectorsStreamingMemoryLimit; + vectorsPerTransfer = vectorsStreamingMemoryLimit / bytesRef.length; + if (totalLiveDocs > 0) { + vectorsPerTransfer = Math.min(vectorsPerTransfer, totalLiveDocs); + } + // This condition comes if vectorsStreamingMemoryLimit is higher than total number floats to transfer // Doing this will reduce 1 extra trip to JNI layer. if (vectorsPerTransfer == 0) { - vectorsPerTransfer = totalLiveDocs; + vectorsPerTransfer = 1; } } @@ -63,7 +68,11 @@ public SerializationMode getSerializationMode(final BytesRef bytesRef) { } private void transfer() { - vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension); + if (totalLiveDocs != 0) { + vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension); + } else { + vectorAddress = JNICommons.storeVectorData(0, vectorList.toArray(new float[][] {}), vectorList.size() * dimension); + } vectorList.clear(); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index 04aeb337f..afde1fac6 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -29,7 +29,7 @@ public class KNNCodecUtil { public static final int JAVA_ROUNDING_NUMBER = 8; @AllArgsConstructor - public static final class Pair { + public static final class VectorBatch { public int[] docs; @Getter @Setter @@ -38,32 +38,50 @@ public static final class Pair { @Setter private int dimension; public SerializationMode serializationMode; + public boolean finished; } /** * Extract docIds and vectors from binary doc values. - * - * @param values Binary doc values - * @param vectorTransfer Utility to make transfer - * @return KNNCodecUtil.Pair representing doc ids and corresponding vectors - * @throws IOException thrown when unable to get binary of vectors - */ - public static KNNCodecUtil.Pair getPair(final BinaryDocValues values, final VectorTransfer vectorTransfer) throws IOException { + * + * @param values Binary doc values + * @param vectorTransfer Utility to make transfer + * @return KNNCodecUtil.Pair representing doc ids and corresponding vectors + * @throws IOException thrown when unable to get binary of vectors + */ + public static KNNCodecUtil.VectorBatch getVectorBatch( + final BinaryDocValues values, + final VectorTransfer vectorTransfer, + boolean iterative + ) throws IOException { List docIdList = new ArrayList<>(); SerializationMode serializationMode = SerializationMode.COLLECTION_OF_FLOATS; - vectorTransfer.init(getTotalLiveDocsCount(values)); - for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) { + if (iterative) { + vectorTransfer.init(0); + } else { + vectorTransfer.init(getTotalLiveDocsCount(values)); + } + int doc = values.nextDoc(); + for (; doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) { BytesRef bytesref = values.binaryValue(); serializationMode = vectorTransfer.getSerializationMode(bytesref); vectorTransfer.transfer(bytesref); docIdList.add(doc); + // Semi-hacky way to check if the streaming limit has been reached + if (iterative && vectorTransfer.getVectorAddress() != 0) { + break; + } } vectorTransfer.close(); - return new KNNCodecUtil.Pair( + + boolean finished = doc == DocIdSetIterator.NO_MORE_DOCS; + + return new KNNCodecUtil.VectorBatch( docIdList.stream().mapToInt(Integer::intValue).toArray(), vectorTransfer.getVectorAddress(), vectorTransfer.getDimension(), - serializationMode + serializationMode, + finished ); } @@ -115,7 +133,7 @@ public static String buildEngineFileSuffix(String fieldName, String extension) { return String.format("_%s%s", fieldName, extension); } - private static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues) { + public static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues) { long totalLiveDocs; if (binaryDocValues instanceof KNN80BinaryDocValues) { totalLiveDocs = ((KNN80BinaryDocValues) binaryDocValues).getTotalLiveDocs(); diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 1f23f6fcd..c00fbfb72 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -49,6 +49,24 @@ class FaissService { }); } + public static native long initIndex(long numDocs, int dim, Map parameters); + + public static native long initBinaryIndex(long numDocs, int dim, Map parameters); + + public static native void insertToIndex(int[] ids, long vectorsAddress, int dim, long indexAddress, Map parameters); + + public static native void insertToBinaryIndex( + int[] ids, + long vectorsAddress, + int dim, + long indexAddress, + Map parameters + ); + + public static native void writeIndex(long indexAddress, String indexPath, Map parameters); + + public static native void writeBinaryIndex(long indexAddress, String indexPath, Map parameters); + /** * Create an index for the native library The memory occupied by the vectorsAddress will be freed up during the * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 2a8d3ea8f..7c56736a5 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -23,6 +23,55 @@ * Service to distribute requests to the proper engine jni service */ public class JNIService { + private static final String FAISS_BINARY_INDEX_PREFIX = "B"; + + public static long initIndexFromScratch(long size, int dim, Map parameters, KNNEngine knnEngine) { + if (KNNEngine.FAISS == knnEngine) { + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { + return FaissService.initBinaryIndex(size, dim, parameters); + } else { + return FaissService.initIndex(size, dim, parameters); + } + } + + throw new IllegalArgumentException( + String.format("initIndexFromScratch not supported for provided engine : %s", knnEngine.getName()) + ); + } + + public static void insertToIndex( + int[] docs, + long vectorAddress, + int dimension, + Map parameters, + long indexAddress, + KNNEngine knnEngine + ) { + if (KNNEngine.FAISS == knnEngine) { + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { + FaissService.insertToBinaryIndex(docs, vectorAddress, dimension, indexAddress, parameters); + } else { + FaissService.insertToIndex(docs, vectorAddress, dimension, indexAddress, parameters); + } + return; + } + + throw new IllegalArgumentException(String.format("insertToIndex not supported for provided engine : %s", knnEngine.getName())); + } + + public static void writeIndex(String indexPath, long indexAddress, KNNEngine knnEngine, Map parameters) { + if (KNNEngine.FAISS == knnEngine) { + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { + FaissService.writeBinaryIndex(indexAddress, indexPath, parameters); + } else { + FaissService.writeIndex(indexAddress, indexPath, parameters); + } + return; + } + + throw new IllegalArgumentException(String.format("writeIndex not supported for provided engine : %s", knnEngine.getName())); + } + /** * Create an index for the native library. The memory occupied by the vectorsAddress will be freed up during the * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java index 0ff2d0516..a9cdcb34e 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java @@ -241,7 +241,7 @@ public int advance(int target) throws IOException { @Override public long cost() { - return 0; + return this.count; } } diff --git a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java index 2ff0f08e5..50ae8d139 100644 --- a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java @@ -39,7 +39,7 @@ public void testGetPair_whenCalled_thenReturn() { when(vectorTransfer.getDimension()).thenReturn(dimension); // Run - KNNCodecUtil.Pair pair = KNNCodecUtil.getPair(binaryDocValues, vectorTransfer); + KNNCodecUtil.VectorBatch batch = KNNCodecUtil.getVectorBatch(binaryDocValues, vectorTransfer, false); // Verify verify(vectorTransfer).init(liveDocCount); @@ -47,9 +47,9 @@ public void testGetPair_whenCalled_thenReturn() { verify(vectorTransfer).transfer(any(BytesRef.class)); verify(vectorTransfer).close(); - assertTrue(Arrays.equals(docId, pair.docs)); - assertEquals(vectorAddress, pair.getVectorAddress()); - assertEquals(dimension, pair.getDimension()); - assertEquals(SerializationMode.COLLECTIONS_OF_BYTES, pair.serializationMode); + assertTrue(Arrays.equals(docId, batch.docs)); + assertEquals(vectorAddress, batch.getVectorAddress()); + assertEquals(dimension, batch.getDimension()); + assertEquals(SerializationMode.COLLECTIONS_OF_BYTES, batch.serializationMode); } }