diff --git a/jni/cmake/init-faiss.cmake-e b/jni/cmake/init-faiss.cmake-e
new file mode 100644
index 0000000000..bef93eda00
--- /dev/null
+++ b/jni/cmake/init-faiss.cmake-e
@@ -0,0 +1,77 @@
+#
+# Copyright OpenSearch Contributors
+# SPDX-License-Identifier: Apache-2.0
+#
+
+# Check if faiss exists
+find_path(FAISS_REPO_DIR NAMES faiss PATHS ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss NO_DEFAULT_PATH)
+
+# If not, pull the updated submodule
+if (NOT EXISTS "${FAISS_REPO_DIR}")
+    message(STATUS "Could not find faiss. Pulling updated submodule.")
+    execute_process(COMMAND git submodule update --init -- external/faiss WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} RESULT_VARIABLE RESULT_CODE)
+    if (RESULT_CODE)
+        message(FATAL_ERROR "Failed to pull the faiss submodule (git exited with ${RESULT_CODE})")
+    endif ()
+endif ()
+
+# Check if patch exist, this is to skip git apply during CI build. See CI.yml with ubuntu.
+set(FAISS_PATCH_FILES
+    0001-Custom-patch-to-support-multi-vector.patch
+    0002-Enable-precomp-table-to-be-shared-ivfpq.patch
+    0003-Custom-patch-to-support-range-search-params.patch)
+find_path(PATCH_FILE NAMES ${FAISS_PATCH_FILES} PATHS ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss NO_DEFAULT_PATH)
+
+# If it exists, apply patches in order. The result of every patch is checked, so that a failure of an
+# earlier patch is not masked by the result of a later one.
+if (EXISTS "${PATCH_FILE}")
+    message(STATUS "Applying custom patches.")
+    foreach (FAISS_PATCH_FILE IN LISTS FAISS_PATCH_FILES)
+        execute_process(COMMAND git ${GIT_PATCH_COMMAND} --3way --ignore-space-change --ignore-whitespace ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/${FAISS_PATCH_FILE} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss ERROR_VARIABLE ERROR_MSG RESULT_VARIABLE RESULT_CODE)
+        if (RESULT_CODE)
+            message(FATAL_ERROR "Failed to apply patch ${FAISS_PATCH_FILE}:\n${ERROR_MSG}")
+        endif ()
+    endforeach ()
+endif ()
+
+if ("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin")
+    if(CMAKE_C_COMPILER_ID MATCHES "Clang\$")
+        set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp")
+        set(OpenMP_C_LIB_NAMES "omp")
+        set(OpenMP_omp_LIBRARY /usr/local/opt/libomp/lib/libomp.dylib)
+    endif()
+
+    if(CMAKE_CXX_COMPILER_ID MATCHES "Clang\$")
+        set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I/usr/local/opt/libomp/include")
+        set(OpenMP_CXX_LIB_NAMES "omp")
+        set(OpenMP_omp_LIBRARY /usr/local/opt/libomp/lib/libomp.dylib)
+    endif()
+endif()
+
+find_package(ZLIB REQUIRED)
+
+# Statically link BLAS - ensure this is before we find the blas package so we dont dynamically link
+set(BLA_STATIC ON)
+find_package(BLAS REQUIRED)
+enable_language(Fortran)
+find_package(LAPACK REQUIRED)
+
+# Set relevant properties
+set(BUILD_TESTING OFF) # Avoid building faiss tests
+set(FAISS_ENABLE_GPU OFF)
+set(FAISS_ENABLE_PYTHON OFF)
+
+if(NOT DEFINED SIMD_ENABLED)
+    set(SIMD_ENABLED true) # set default value as true if the argument is not set
+endif()
+
+if("${CMAKE_SYSTEM_NAME}" STREQUAL "Windows" OR "${CMAKE_SYSTEM_PROCESSOR}" MATCHES "aarch64" OR "${CMAKE_SYSTEM_PROCESSOR}" MATCHES "arm64" OR NOT SIMD_ENABLED)
+    set(FAISS_OPT_LEVEL generic) # Keep optimization level as generic on Windows OS as it is not supported due to MINGW64 compiler issue. Also, on aarch64 avx2 is not supported.
+    set(TARGET_LINK_FAISS_LIB faiss)
+else()
+    set(FAISS_OPT_LEVEL avx2) # Keep optimization level as avx2 to improve performance on Linux and Mac.
+    set(TARGET_LINK_FAISS_LIB faiss_avx2)
+    string(PREPEND LIB_EXT "_avx2") # Prepend "_avx2" to lib extension to create the library as "libopensearchknn_faiss_avx2.so" on linux and "libopensearchknn_faiss_avx2.jnilib" on mac
+endif()
+
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/external/faiss EXCLUDE_FROM_ALL)
diff --git a/jni/include/faiss_index_service.h b/jni/include/faiss_index_service.h
index 59f15fda9c..9de374950d 100644
--- a/jni/include/faiss_index_service.h
+++ b/jni/include/faiss_index_service.h
@@ -31,8 +31,51 @@ namespace faiss_wrapper {
 class IndexService {
 public:
     IndexService(std::unique_ptr<FaissMethods> faissMethods);
-    //TODO Remove dependency on JNIUtilInterface and JNIEnv
-    //TODO Reduce the number of parameters
+    /**
+     * Initialize index
+     *
+     * @param jniUtil jni util
+     * @param env jni environment
+     * @param metric space type for distance calculation
+     * @param indexDescription index description to be used by faiss index factory
+     * @param dim dimension of vectors
+     * @param numVectors number of vectors
+     * @param threadCount number of threads to be used while adding data
+     * @param parameters parameters to be applied to faiss index
+     * @return address of the native index object, to be passed to insertToIndex and writeIndex
+     */
+    virtual jlong initIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numVectors, int threadCount, std::unordered_map<std::string, jobject> parameters);
+    /**
+     * Add vectors to index
+     *
+     * @param jniUtil jni util
+     * @param env jni environment
+     * @param metric space type for distance calculation
+     * @param indexDescription index description to be used by faiss index factory
+     * @param dim dimension of vectors
+     * @param numIds number of vectors
+     * @param threadCount number of threads to be used while adding data
+     * @param vectorsAddress memory address which is holding vector data
+     * @param ids ids of the vectors, one per vector
+     * @param idMapAddress address of the native index, as returned by initIndex
+     * @param parameters parameters to be applied to faiss index
+     */
+    virtual void insertToIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numIds, int threadCount, int64_t vectorsAddress, std::vector<int64_t> &ids, jlong idMapAddress, std::unordered_map<std::string, jobject> parameters);
+    /**
+     * Write index to disk and free the native index
+     *
+     * @param jniUtil jni util
+     * @param env jni environment
+     * @param metric space type for distance calculation
+     * @param indexDescription index description to be used by faiss index factory
+     * @param threadCount number of threads to be used while writing data
+     * @param indexPath path to write index
+     * @param idMapAddress address of the native index, as returned by initIndex
+     * @param parameters parameters to be applied to faiss index
+     */
+    virtual void writeIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int threadCount, std::string indexPath, jlong idMapAddress, std::unordered_map<std::string, jobject> parameters);
+    // TODO Remove dependency on JNIUtilInterface and JNIEnv
+    // TODO Reduce the number of parameters
 
     /**
      * Create index
@@ -75,6 +118,49 @@ class BinaryIndexService : public IndexService {
     //TODO Remove dependency on JNIUtilInterface and JNIEnv
     //TODO Reduce the number of parameters
     BinaryIndexService(std::unique_ptr<FaissMethods> faissMethods);
+    /**
+     * Initialize index
+     *
+     * @param jniUtil jni util
+     * @param env jni environment
+     * @param metric space type for distance calculation
+     * @param indexDescription index description to be used by faiss index factory
+     * @param dim dimension of vectors
+     * @param numVectors number of vectors
+     * @param threadCount number of threads to be used while adding data
+     * @param parameters parameters to be applied to faiss index
+     * @return address of the native index object, to be passed to insertToIndex and writeIndex
+     */
+    virtual jlong initIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numVectors, int threadCount, std::unordered_map<std::string, jobject> parameters) override;
+    /**
+     * Add vectors to index
+     *
+     * @param jniUtil jni util
+     * @param env jni environment
+     * @param metric space type for distance calculation
+     * @param indexDescription index description to be used by faiss index factory
+     * @param dim dimension of vectors
+     * @param numIds number of vectors
+     * @param threadCount number of threads to be used while adding data
+     * @param vectorsAddress memory address which is holding vector data
+     * @param ids ids of the vectors, one per vector
+     * @param idMapAddress address of the native index, as returned by initIndex
+     * @param parameters parameters to be applied to faiss index
+     */
+    virtual void insertToIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numIds, int threadCount, int64_t vectorsAddress, std::vector<int64_t> &ids, jlong idMapAddress, std::unordered_map<std::string, jobject> parameters) override;
+    /**
+     * Write index to disk and free the native index
+     *
+     * @param jniUtil jni util
+     * @param env jni environment
+     * @param metric space type for distance calculation
+     * @param indexDescription index description to be used by faiss index factory
+     * @param threadCount number of threads to be used while writing data
+     * @param indexPath path to write index
+     * @param idMapAddress address of the native index, as returned by initIndex
+     * @param parameters parameters to be applied to faiss index
+     */
+    virtual void writeIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int threadCount, std::string indexPath, jlong idMapAddress, std::unordered_map<std::string, jobject> parameters) override;
     /**
      * Create binary index
      *
diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h
index 2b9bc2c767..61e14b196f 100644
--- a/jni/include/faiss_wrapper.h
+++ b/jni/include/faiss_wrapper.h
@@ -18,6 +18,17 @@
 
 namespace knn_jni {
     namespace faiss_wrapper {
+        // Create an empty index for vectors of dimension dimJ. numDocs is the number of vectors that are expected to
+        // be inserted. The configuration is defined by values in the Java map, parametersJ.
+        // Returns the address of the native index, to be passed to InsertToIndex and WriteIndex.
+        jlong InitIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong numDocs, jint dimJ, jobject parametersJ, IndexService *indexService);
+
+        // Add the vectors stored at vectorsAddressJ, with the ids in idsJ, to the index at address index_ptr.
+        void InsertToIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jlong index_ptr, jobject parametersJ, IndexService *indexService);
+
+        // Serialize the index at address index_ptr to indexPathJ and free the native index.
+        void WriteIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jstring indexPathJ, jlong index_ptr, jobject parametersJ, IndexService *indexService);
+
         // Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ.
         // The index is serialized to indexPathJ.
         void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ,
diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h
index 7cc071ff38..3458c3c0ee 100644
--- a/jni/include/org_opensearch_knn_jni_FaissService.h
+++ b/jni/include/org_opensearch_knn_jni_FaissService.h
@@ -18,6 +18,54 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
+/*
+ * Class:     org_opensearch_knn_jni_FaissService
+ * Method:    initIndex
+ * Signature: (JILjava/util/Map;)J
+ */
+JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndex(JNIEnv * env, jclass cls,
+                                                                           jlong numDocs, jint dimJ,
+                                                                           jobject parametersJ);
+/*
+ * Class:     org_opensearch_knn_jni_FaissService
+ * Method:    initBinaryIndex
+ * Signature: (JILjava/util/Map;)J
+ */
+JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndex(JNIEnv * env, jclass cls,
+                                                                                 jlong numDocs, jint dimJ,
+                                                                                 jobject parametersJ);
+/*
+ * Class:     org_opensearch_knn_jni_FaissService
+ * Method:    insertToIndex
+ * Signature: ([IJIJLjava/util/Map;)V
+ */
+JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JNIEnv * env, jclass cls, jintArray idsJ,
+                                                                              jlong vectorsAddressJ, jint dimJ,
+                                                                              jlong indexAddress, jobject parametersJ);
+/*
+ * Class:     org_opensearch_knn_jni_FaissService
+ * Method:    insertToBinaryIndex
+ * Signature: ([IJIJLjava/util/Map;)V
+ */
+JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ,
+                                                                                    jlong vectorsAddressJ, jint dimJ,
+                                                                                    jlong indexAddress, jobject parametersJ);
+/*
+ * Class:     org_opensearch_knn_jni_FaissService
+ * Method:    writeIndex
+ * Signature: (JLjava/lang/String;Ljava/util/Map;)V
+ */
+JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEnv * env, jclass cls,
+                                                                           jlong indexAddress,
+                                                                           jstring indexPathJ, jobject parametersJ);
+/*
+ * Class:     org_opensearch_knn_jni_FaissService
+ * Method:    writeBinaryIndex
+ * Signature: (JLjava/lang/String;Ljava/util/Map;)V
+ */
+JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex(JNIEnv * env, jclass cls,
+                                                                                 jlong indexAddress,
+                                                                                 jstring indexPathJ, jobject parametersJ);
 
 /*
  * Class:     org_opensearch_knn_jni_FaissService
diff --git a/jni/src/faiss_index_service.cpp b/jni/src/faiss_index_service.cpp
index 8c5ba36af2..78f8fcb3b4 100644
--- a/jni/src/faiss_index_service.cpp
+++ b/jni/src/faiss_index_service.cpp
@@ -57,6 +57,111 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env,
 
 IndexService::IndexService(std::unique_ptr<FaissMethods> faissMethods) : faissMethods(std::move(faissMethods)) {}
 
+jlong IndexService::initIndex(
+        knn_jni::JNIUtilInterface * jniUtil,
+        JNIEnv * env,
+        faiss::MetricType metric,
+        std::string indexDescription,
+        int dim,
+        int numVectors,
+        int threadCount,
+        std::unordered_map<std::string, jobject> parameters
+    ) {
+    // Create the index with the faiss factory. It is held in a unique_ptr so that it is freed if the set-up below throws.
+    std::unique_ptr<faiss::Index> indexWriter(faissMethods->indexFactory(dim, indexDescription.c_str(), metric));
+
+    // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
+    if(threadCount != 0) {
+        omp_set_num_threads(threadCount);
+    }
+
+    // Add extra parameters that cant be configured with the index factory
+    SetExtraParameters(jniUtil, env, parameters, indexWriter.get());
+
+    // Check that the index does not need to be trained
+    if(!indexWriter->is_trained) {
+        throw std::runtime_error("Index is not trained");
+    }
+
+    // Wrap the index in an id map so that vectors can be added with caller supplied ids
+    std::unique_ptr<faiss::IndexIDMap> idMap(faissMethods->indexIdMap(indexWriter.get()));
+    // Make the id map own the wrapped index, so that deleting the id map (see writeIndex) frees both
+    idMap->own_fields = true;
+    indexWriter.release();
+
+    faiss::IndexHNSW * hnsw = dynamic_cast<faiss::IndexHNSW *>(idMap->index);
+
+    if(hnsw != NULL) {
+        faiss::IndexFlat * storage = dynamic_cast<faiss::IndexFlat *>(hnsw->storage);
+        if(storage != NULL) {
+            // codes is a byte buffer, so reserve code_size bytes per vector. Computed in size_t to avoid int overflow
+            storage->codes.reserve(storage->code_size * static_cast<size_t>(numVectors));
+        }
+    }
+
+    // Ownership passes to the caller. The index is freed by writeIndex
+    return reinterpret_cast<jlong>(idMap.release());
+}
+
+void IndexService::insertToIndex(
+        knn_jni::JNIUtilInterface * jniUtil,
+        JNIEnv * env,
+        faiss::MetricType metric,
+        std::string indexDescription,
+        int dim,
+        int numIds,
+        int threadCount,
+        int64_t vectorsAddress,
+        std::vector<int64_t> & ids,
+        jlong idMapAddress,
+        std::unordered_map<std::string, jobject> parameters
+    ) {
+    // Read vectors from memory address (not owned here: the JNI caller frees the vector after the insert)
+    std::vector<float> * inputVectors = reinterpret_cast<std::vector<float>*>(vectorsAddress);
+
+    // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value
+    int numVectors = (int) (inputVectors->size() / (uint64_t) dim);
+    if(numVectors == 0) {
+        throw std::runtime_error("Number of vectors cannot be 0");
+    }
+
+    if (numIds != numVectors) {
+        throw std::runtime_error("Number of IDs does not match number of vectors");
+    }
+
+    // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
+    if(threadCount != 0) {
+        omp_set_num_threads(threadCount);
+    }
+
+    faiss::IndexIDMap * idMap = reinterpret_cast<faiss::IndexIDMap *> (idMapAddress);
+
+    // Add vectors
+    idMap->add_with_ids(numVectors, inputVectors->data(), ids.data());
+}
+
+void IndexService::writeIndex(
+        knn_jni::JNIUtilInterface * jniUtil,
+        JNIEnv * env,
+        faiss::MetricType metric,
+        std::string indexDescription,
+        int threadCount,
+        std::string indexPath,
+        jlong idMapAddress,
+        std::unordered_map<std::string, jobject> parameters
+    ) {
+    // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
+    if(threadCount != 0) {
+        omp_set_num_threads(threadCount);
+    }
+
+    // Take ownership of the index so that its memory is freed when this function returns, even if the write throws
+    std::unique_ptr<faiss::IndexIDMap> idMap(reinterpret_cast<faiss::IndexIDMap *> (idMapAddress));
+
+    // Write the index to disk
+    faissMethods->writeIndex(idMap.get(), indexPath.c_str());
+}
+
 void IndexService::createIndex(
         knn_jni::JNIUtilInterface * jniUtil,
         JNIEnv * env,
@@ -108,6 +213,111 @@ void IndexService::createIndex(
 
 BinaryIndexService::BinaryIndexService(std::unique_ptr<FaissMethods> faissMethods) : IndexService(std::move(faissMethods)) {}
 
+jlong BinaryIndexService::initIndex(
+        knn_jni::JNIUtilInterface * jniUtil,
+        JNIEnv * env,
+        faiss::MetricType metric,
+        std::string indexDescription,
+        int dim,
+        int numVectors,
+        int threadCount,
+        std::unordered_map<std::string, jobject> parameters
+    ) {
+    // Create the index with the faiss factory. It is held in a unique_ptr so that it is freed if the set-up below throws.
+    std::unique_ptr<faiss::IndexBinary> indexWriter(faissMethods->indexBinaryFactory(dim, indexDescription.c_str()));
+
+    // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
+    if(threadCount != 0) {
+        omp_set_num_threads(threadCount);
+    }
+
+    // Add extra parameters that cant be configured with the index factory
+    SetExtraParameters(jniUtil, env, parameters, indexWriter.get());
+
+    // Check that the index does not need to be trained
+    if(!indexWriter->is_trained) {
+        throw std::runtime_error("Index is not trained");
+    }
+
+    // Wrap the index in an id map so that vectors can be added with caller supplied ids
+    std::unique_ptr<faiss::IndexBinaryIDMap> idMap(faissMethods->indexBinaryIdMap(indexWriter.get()));
+    // Make the id map own the wrapped index, so that deleting the id map (see writeIndex) frees both
+    idMap->own_fields = true;
+    indexWriter.release();
+
+    faiss::IndexBinaryHNSW * hnsw = dynamic_cast<faiss::IndexBinaryHNSW *>(idMap->index);
+
+    if(hnsw != NULL) {
+        faiss::IndexBinaryFlat * storage = dynamic_cast<faiss::IndexBinaryFlat *>(hnsw->storage);
+        if(storage != NULL) {
+            // xb holds dim / 8 bytes per vector. Computed in size_t so that the product cannot overflow int
+            storage->xb.reserve(static_cast<size_t>(dim / 8) * numVectors);
+        }
+    }
+
+    // Ownership passes to the caller. The index is freed by writeIndex
+    return reinterpret_cast<jlong>(idMap.release());
+}
+
+void BinaryIndexService::insertToIndex(
+        knn_jni::JNIUtilInterface * jniUtil,
+        JNIEnv * env,
+        faiss::MetricType metric,
+        std::string indexDescription,
+        int dim,
+        int numIds,
+        int threadCount,
+        int64_t vectorsAddress,
+        std::vector<int64_t> & ids,
+        jlong idMapAddress,
+        std::unordered_map<std::string, jobject> parameters
+    ) {
+    // Read vectors from memory address (not owned here: the JNI caller frees the vector after the insert)
+    std::vector<uint8_t> * inputVectors = reinterpret_cast<std::vector<uint8_t>*>(vectorsAddress);
+
+    // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value
+    int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8));
+    if(numVectors == 0) {
+        throw std::runtime_error("Number of vectors cannot be 0");
+    }
+
+    if (numIds != numVectors) {
+        throw std::runtime_error("Number of IDs does not match number of vectors");
+    }
+
+    // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
+    if(threadCount != 0) {
+        omp_set_num_threads(threadCount);
+    }
+
+    faiss::IndexBinaryIDMap * idMap = reinterpret_cast<faiss::IndexBinaryIDMap *> (idMapAddress);
+
+    // Add vectors
+    idMap->add_with_ids(numVectors, inputVectors->data(), ids.data());
+}
+
+void BinaryIndexService::writeIndex(
+        knn_jni::JNIUtilInterface * jniUtil,
+        JNIEnv * env,
+        faiss::MetricType metric,
+        std::string indexDescription,
+        int threadCount,
+        std::string indexPath,
+        jlong idMapAddress,
+        std::unordered_map<std::string, jobject> parameters
+    ) {
+    // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
+    if(threadCount != 0) {
+        omp_set_num_threads(threadCount);
+    }
+
+    // Take ownership of the index so that its memory is freed when this function returns, even if the write throws
+    std::unique_ptr<faiss::IndexBinaryIDMap> idMap(reinterpret_cast<faiss::IndexBinaryIDMap *> (idMapAddress));
+
+    // Write the index to disk
+    faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str());
+}
+
 void BinaryIndexService::createIndex(
         knn_jni::JNIUtilInterface * jniUtil,
         JNIEnv * env,
diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp
index 45830aff6c..111bf25d8d 100644
--- a/jni/src/faiss_wrapper.cpp
+++ b/jni/src/faiss_wrapper.cpp
@@ -85,6 +85,168 @@ bool isIndexIVFPQL2(faiss::Index * index);
 // IndexIDMap which has member that will point to underlying index that stores the data
 faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index);
 
+jlong knn_jni::faiss_wrapper::InitIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong numDocs, jint dimJ,
+                                        jobject parametersJ, IndexService* indexService) {
+
+    if(dimJ <= 0) {
+        throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
+    }
+
+    if (parametersJ == nullptr) {
+        throw std::runtime_error("Parameters cannot be null");
+    }
+
+    // parametersJ is a Java Map. ConvertJavaMapToCppMap converts it to a c++ map
+    // so that it is easier to access.
+    auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);
+
+    // Parameters to pass
+    // Metric type
+    jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE);
+    std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ));
+    faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp);
+    jniUtil->DeleteLocalRef(env, spaceTypeJ);
+
+    // Dimension
+    int dim = (int)dimJ;
+
+    // Number of docs
+    int docs = (int)numDocs;
+
+    // Index description
+    jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION);
+    std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ));
+    jniUtil->DeleteLocalRef(env, indexDescriptionJ);
+
+    // Thread count
+    int threadCount = 0;
+    if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
+        threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
+    }
+
+    // Extra parameters
+    // TODO: parse the entire map and remove jni object
+    std::unordered_map<std::string, jobject> subParametersCpp;
+    if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) {
+        subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersCpp[knn_jni::PARAMETERS]);
+    }
+    // end parameters to pass
+
+    // Create index
+    return indexService->initIndex(jniUtil, env, metric, indexDescriptionCpp, dim, docs, threadCount, subParametersCpp);
+}
+void knn_jni::faiss_wrapper::InsertToIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ,
+                                           jlong index_ptr, jobject parametersJ, IndexService* indexService) {
+    if (idsJ == nullptr) {
+        throw std::runtime_error("IDs cannot be null");
+    }
+
+    if (vectorsAddressJ <= 0) {
+        throw std::runtime_error("VectorsAddress cannot be less than 0");
+    }
+
+    if(dimJ <= 0) {
+        throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
+    }
+
+    if (parametersJ == nullptr) {
+        throw std::runtime_error("Parameters cannot be null");
+    }
+
+    // parametersJ is a Java Map. ConvertJavaMapToCppMap converts it to a c++ map
+    // so that it is easier to access.
+    auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);
+
+    // Parameters to pass
+    // Metric type
+    jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE);
+    std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ));
+    faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp);
+    jniUtil->DeleteLocalRef(env, spaceTypeJ);
+
+    // Dimension
+    int dim = (int)dimJ;
+
+    // Number of vectors
+    int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
+
+    // Index description
+    jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION);
+    std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ));
+    jniUtil->DeleteLocalRef(env, indexDescriptionJ);
+
+    // Thread count
+    int threadCount = 0;
+    if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
+        threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
+    }
+
+    // Vectors address
+    int64_t vectorsAddress = (int64_t)vectorsAddressJ;
+
+    // Ids
+    auto ids = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
+
+    // Extra parameters
+    // TODO: parse the entire map and remove jni object
+    std::unordered_map<std::string, jobject> subParametersCpp;
+    if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) {
+        subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersCpp[knn_jni::PARAMETERS]);
+    }
+    // end parameters to pass
+
+    // Insert the vectors into the index
+    indexService->insertToIndex(jniUtil, env, metric, indexDescriptionCpp, dim, numIds, threadCount, vectorsAddress, ids, index_ptr, subParametersCpp);
+}
+
+void knn_jni::faiss_wrapper::WriteIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env,
+                                        jstring indexPathJ, jlong index_ptr, jobject parametersJ, IndexService* indexService) {
+
+    if (indexPathJ == nullptr) {
+        throw std::runtime_error("Index path cannot be null");
+    }
+
+    if (parametersJ == nullptr) {
+        throw std::runtime_error("Parameters cannot be null");
+    }
+
+    // parametersJ is a Java Map. ConvertJavaMapToCppMap converts it to a c++ map
+    // so that it is easier to access.
+    auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);
+
+    // Parameters to pass
+    // Metric type
+    jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE);
+    std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ));
+    faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp);
+    jniUtil->DeleteLocalRef(env, spaceTypeJ);
+
+    // Index description
+    jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION);
+    std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ));
+    jniUtil->DeleteLocalRef(env, indexDescriptionJ);
+
+    // Thread count
+    int threadCount = 0;
+    if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
+        threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
+    }
+
+    // Index path
+    std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
+
+    // Extra parameters
+    // TODO: parse the entire map and remove jni object
+    std::unordered_map<std::string, jobject> subParametersCpp;
+    if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) {
+        subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersCpp[knn_jni::PARAMETERS]);
+    }
+    // end parameters to pass
+
+    // Write the index to disk
+    indexService->writeIndex(jniUtil, env, metric, indexDescriptionCpp, threadCount, indexPathCpp, index_ptr, subParametersCpp);
+}
+
 void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ,
                                          jstring indexPathJ, jobject parametersJ, IndexService* indexService) {
     if (idsJ == nullptr) {
diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp
index 6e447b0347..b59e56a47c 100644
--- a/jni/src/org_opensearch_knn_jni_FaissService.cpp
+++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp
@@ -39,6 +39,96 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) {
     jniUtil.Uninitialize(env);
 }
 
+JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndex(JNIEnv * env, jclass cls,
+                                                                           jlong numDocs, jint dimJ,
+                                                                           jobject parametersJ)
+{
+    try {
+        std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
+        knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods));
+        return knn_jni::faiss_wrapper::InitIndex(&jniUtil, env, numDocs, dimJ, parametersJ, &indexService);
+    } catch (...) {
+        jniUtil.CatchCppExceptionAndThrowJava(env);
+    }
+    return (jlong)0;
+}
+
+JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndex(JNIEnv * env, jclass cls,
+                                                                                 jlong numDocs, jint dimJ,
+                                                                                 jobject parametersJ)
+{
+    try {
+        std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
+        knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods));
+        return knn_jni::faiss_wrapper::InitIndex(&jniUtil, env, numDocs, dimJ, parametersJ, &binaryIndexService);
+    } catch (...) {
+        jniUtil.CatchCppExceptionAndThrowJava(env);
+    }
+    return (jlong)0;
+}
+
+JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JNIEnv * env, jclass cls, jintArray idsJ,
+                                                                              jlong vectorsAddressJ, jint dimJ,
+                                                                              jlong indexAddress, jobject parametersJ)
+{
+    try {
+        std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
+        knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods));
+        knn_jni::faiss_wrapper::InsertToIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexAddress, parametersJ, &indexService);
+
+        // Releasing the vectorsAddressJ memory as that is not required once we have created the index.
+        // This is not the ideal approach, please refer this gh issue for long term solution:
+        // https://github.com/opensearch-project/k-NN/issues/1600
+        delete reinterpret_cast<std::vector<float>*>(vectorsAddressJ);
+    } catch (...) {
+        jniUtil.CatchCppExceptionAndThrowJava(env);
+    }
+}
+
+JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ,
+                                                                                    jlong vectorsAddressJ, jint dimJ,
+                                                                                    jlong indexAddress, jobject parametersJ)
+{
+    try {
+        std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
+        knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods));
+        knn_jni::faiss_wrapper::InsertToIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexAddress, parametersJ, &binaryIndexService);
+
+        // Releasing the vectorsAddressJ memory as that is not required once we have created the index.
+        // This is not the ideal approach, please refer this gh issue for long term solution:
+        // https://github.com/opensearch-project/k-NN/issues/1600
+        delete reinterpret_cast<std::vector<uint8_t>*>(vectorsAddressJ);
+    } catch (...) {
+        jniUtil.CatchCppExceptionAndThrowJava(env);
+    }
+}
+
+JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEnv * env, jclass cls,
+                                                                           jlong indexAddress,
+                                                                           jstring indexPathJ, jobject parametersJ)
+{
+    try {
+        std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
+        knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods));
+        knn_jni::faiss_wrapper::WriteIndex(&jniUtil, env, indexPathJ, indexAddress, parametersJ, &indexService);
+    } catch (...) {
+        jniUtil.CatchCppExceptionAndThrowJava(env);
+    }
+}
+
+JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex(JNIEnv * env, jclass cls,
+                                                                                 jlong indexAddress,
+                                                                                 jstring indexPathJ, jobject parametersJ)
+{
+    try {
+        std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
+        knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods));
+        knn_jni::faiss_wrapper::WriteIndex(&jniUtil, env, indexPathJ, indexAddress, parametersJ, &binaryIndexService);
+    } catch (...) {
+        jniUtil.CatchCppExceptionAndThrowJava(env);
+    }
+}
+
 JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex(JNIEnv * env, jclass cls, jintArray idsJ,
                                                                             jlong vectorsAddressJ, jint dimJ, jstring indexPathJ,
                                                                             jobject parametersJ)
diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java
index 50c1c92715..1afef70d0d 100644
--- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java
+++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java
@@ -120,18 +120,15 @@ private VectorTransfer getVectorTransfer(FieldInfo field) {
 
     public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh)
         throws IOException {
-        // Get values to be indexed
-        BinaryDocValues values = valuesProducer.getBinary(field);
-        KNNCodecUtil.Pair pair = KNNCodecUtil.getPair(values, getVectorTransfer(field));
-        if (pair.getVectorAddress() == 0 || pair.docs.length == 0) {
-            logger.info("Skipping engine index creation as there are no vectors or docs in the segment");
+        // Don't know why this would be null. Can't do any creation when it is, so best to just return.
+ if (field == null) { + logger.info("Field is null!\n"); return; } - long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode); + // Get values to be indexed + BinaryDocValues values = valuesProducer.getBinary(field); if (isMerge) { KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); - KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(pair.docs.length); - KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize); } // Increment counter for number of graph index requests KNNCounter.GRAPH_INDEX_REQUESTS.increment(); @@ -146,35 +143,29 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, ((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName ).toString(); - NativeIndexCreator indexCreator; // Create library index either from model or from scratch - if (field.attributes().containsKey(MODEL_ID)) { - String modelId = field.attributes().get(MODEL_ID); - Model model = ModelCache.getInstance().get(modelId); - if (model.getModelBlob() == null) { - throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); - } - indexCreator = () -> createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, indexPath); - } else { - indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath); - } - - if (isMerge) { - recordMergeStats(pair.docs.length, arraySize); - } - - if (isRefresh) { - recordRefreshStats(); - } // This is a bit of a hack. We have to create an output here and then immediately close it to ensure that // engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper. Otherwise, the file will // not be marked as added to the directory. 
+ state.directory.createOutput(engineFileName, state.context).close(); - indexCreator.createIndex(); + boolean fromScratch = !field.attributes().containsKey(MODEL_ID); + boolean iterative = fromScratch && KNNEngine.FAISS == knnEngine; + createKNNIndex(field, values, knnEngine, indexPath, fromScratch, iterative, isMerge); + + if (isRefresh) { + recordRefreshStats(); + } writeFooter(indexPath, engineFileName); } + private void currentMergeStats(int length, long arraySize) { + KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); + KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(length); + KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize); + } + private void recordMergeStats(int length, long arraySize) { KNNGraphValue.MERGE_CURRENT_OPERATIONS.decrement(); KNNGraphValue.MERGE_CURRENT_DOCS.decrementBy(length); @@ -188,77 +179,208 @@ private void recordRefreshStats() { KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment(); } - private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) { - Map parameters = ImmutableMap.of( - KNNConstants.INDEX_THREAD_QTY, - KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) - ); + private Map genParameters(boolean fromScratch, FieldInfo fieldInfo, KNNEngine knnEngine) throws IOException { + Map parameters; + if (fromScratch) { + parameters = new HashMap<>(); + Map fieldAttributes = fieldInfo.attributes(); + String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); + + // parametersString will be null when legacy mapper is used + if (parametersString == null) { + parameters.put( + KNNConstants.SPACE_TYPE, + fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()) + ); + + String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION); + Map algoParams = new HashMap<>(); + if (efConstruction != null) { + algoParams.put(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, 
Integer.parseInt(efConstruction)); + } + + String m = fieldAttributes.get(KNNConstants.HNSW_ALGO_M); + if (m != null) { + algoParams.put(KNNConstants.METHOD_PARAMETER_M, Integer.parseInt(m)); + } + parameters.put(PARAMETERS, algoParams); + + // Update index description of Faiss for binary data type + if (KNNEngine.FAISS == knnEngine + && VectorDataType.BINARY.getValue() + .equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue())) + && parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null) { + parameters.put( + KNNConstants.INDEX_DESCRIPTION_PARAMETER, + FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() + ); + parameters.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue()); + } + } else { + parameters.putAll( + XContentHelper.createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(parametersString), + MediaTypeRegistry.getDefaultMediaType() + ).map() + ); + } + + // Used to determine how many threads to use when indexing + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + } else { + parameters = ImmutableMap.of( + KNNConstants.INDEX_THREAD_QTY, + KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) + ); + } + return parameters; + } + + private long initIndexFromScratch(long size, int dim, KNNEngine knnEngine, Map parameters) throws IOException { + // Pass the path for the nms library to save the file + return AccessController.doPrivileged((PrivilegedAction) () -> { + return JNIService.initIndexFromScratch(size, dim, parameters, knnEngine); + }); + } + + private void insertToIndex(KNNCodecUtil.VectorBatch pair, KNNEngine knnEngine, long indexAddress, Map parameters) + throws IOException { + AccessController.doPrivileged((PrivilegedAction) () -> { + 
JNIService.insertToIndex(pair.docs, pair.getVectorAddress(), pair.getDimension(), parameters, indexAddress, knnEngine); + return null; + }); + } + + private void writeIndex(long indexAddress, String indexPath, KNNEngine knnEngine, Map parameters) throws IOException { + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.writeIndex(indexPath, indexAddress, knnEngine, parameters); + return null; + }); + } + + private void createKNNIndexFromTemplate( + FieldInfo field, + BinaryDocValues values, + KNNEngine knnEngine, + String indexPath, + Map parameters, + boolean isMerge + ) throws IOException { + String modelId = field.attributes().get(MODEL_ID); + Model model = ModelCache.getInstance().get(modelId); + if (model.getModelBlob() == null) { + throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); + } + byte[] modelBlob = model.getModelBlob(); + KNNCodecUtil.VectorBatch batch = KNNCodecUtil.getVectorBatch(values, getVectorTransfer(field), false); + + int numDocs = (int) KNNCodecUtil.getTotalLiveDocsCount(values); + long arraySize = calculateArraySize(numDocs, batch.getDimension(), batch.serializationMode); + + if (isMerge) { + currentMergeStats(numDocs, arraySize); + } + AccessController.doPrivileged((PrivilegedAction) () -> { JNIService.createIndexFromTemplate( - pair.docs, - pair.getVectorAddress(), - pair.getDimension(), + batch.docs, + batch.getVectorAddress(), + batch.getDimension(), indexPath, - model, + modelBlob, parameters, knnEngine ); return null; }); + + if (isMerge) { + recordMergeStats(numDocs, arraySize); + } } - private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) - throws IOException { - Map parameters = new HashMap<>(); - Map fieldAttributes = fieldInfo.attributes(); - String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); - - // parametersString will be null when legacy mapper is used - if (parametersString 
== null) { - parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue())); - - String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION); - Map algoParams = new HashMap<>(); - if (efConstruction != null) { - algoParams.put(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, Integer.parseInt(efConstruction)); - } + private void createKNNIndexFromScratch( + FieldInfo fieldInfo, + BinaryDocValues values, + KNNEngine knnEngine, + String indexPath, + Map parameters, + boolean isMerge + ) throws IOException { + VectorTransfer transfer = getVectorTransfer(fieldInfo); + KNNCodecUtil.VectorBatch batch = KNNCodecUtil.getVectorBatch(values, transfer, false); - String m = fieldAttributes.get(KNNConstants.HNSW_ALGO_M); - if (m != null) { - algoParams.put(KNNConstants.METHOD_PARAMETER_M, Integer.parseInt(m)); - } - parameters.put(PARAMETERS, algoParams); - } else { - parameters.putAll( - XContentHelper.createParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, - new BytesArray(parametersString), - MediaTypeRegistry.getDefaultMediaType() - ).map() - ); - } + int numDocs = (int) KNNCodecUtil.getTotalLiveDocsCount(values); + long arraySize = calculateArraySize(numDocs, batch.getDimension(), batch.serializationMode); - // Update index description of Faiss for binary data type - if (KNNEngine.FAISS == knnEngine - && VectorDataType.BINARY.getValue() - .equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue())) - && parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null) { - parameters.put( - KNNConstants.INDEX_DESCRIPTION_PARAMETER, - FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() - ); - parameters.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue()); + if (isMerge) { + currentMergeStats(numDocs, arraySize); } - // Used to 
determine how many threads to use when indexing - parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); - - // Pass the path for the nms library to save the file AccessController.doPrivileged((PrivilegedAction) () -> { - JNIService.createIndex(pair.docs, pair.getVectorAddress(), pair.getDimension(), indexPath, parameters, knnEngine); + JNIService.createIndex(batch.docs, batch.getVectorAddress(), batch.getDimension(), indexPath, parameters, knnEngine); return null; }); + + if (isMerge) { + recordMergeStats(numDocs, arraySize); + } + } + + private void createKNNIndexFromScratchIteratively( + FieldInfo fieldInfo, + BinaryDocValues values, + KNNEngine knnEngine, + String indexPath, + Map parameters, + boolean isMerge + ) throws IOException { + VectorTransfer transfer = getVectorTransfer(fieldInfo); + KNNCodecUtil.VectorBatch batch = KNNCodecUtil.getVectorBatch(values, transfer, true); + + int numDocs = (int) KNNCodecUtil.getTotalLiveDocsCount(values); + long arraySize = calculateArraySize(numDocs, batch.getDimension(), batch.serializationMode); + + if (isMerge) { + currentMergeStats(numDocs, arraySize); + } + + long indexAddress = initIndexFromScratch(numDocs, batch.getDimension(), knnEngine, parameters); + for(; !batch.finished; batch = KNNCodecUtil.getVectorBatch(values, transfer, true)) { + insertToIndex(batch, knnEngine, indexAddress, parameters); + } + writeIndex(indexAddress, indexPath, knnEngine, parameters); + if (isMerge) { + recordMergeStats(numDocs, arraySize); + } + } + + private void createKNNIndex( + FieldInfo fieldInfo, + BinaryDocValues values, + KNNEngine knnEngine, + String indexPath, + boolean fromScratch, + boolean iterative, + boolean isMerge + ) throws IOException { + Map parameters = genParameters(fromScratch, fieldInfo, knnEngine); + if (fromScratch && iterative) { + createKNNIndexFromScratchIteratively(fieldInfo, values, knnEngine, indexPath, parameters, isMerge); + } 
else if (fromScratch) { + createKNNIndexFromScratch(fieldInfo, values, knnEngine, indexPath, parameters, isMerge); + } else { + createKNNIndexFromTemplate(fieldInfo, values, knnEngine, indexPath, parameters, isMerge); + } + /* + if(fromScratch) { + createKNNIndexFromScratch(fieldInfo, values, knnEngine, indexPath, parameters, isMerge); + } else { + createKNNIndexFromTemplate(fieldInfo, values, knnEngine, indexPath, parameters, isMerge); + } + */ } /** diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index 04aeb337fd..c9d2688cf5 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -29,7 +29,7 @@ public class KNNCodecUtil { public static final int JAVA_ROUNDING_NUMBER = 8; @AllArgsConstructor - public static final class Pair { + public static final class VectorBatch { public int[] docs; @Getter @Setter @@ -38,32 +38,43 @@ public static final class Pair { @Setter private int dimension; public SerializationMode serializationMode; + public boolean finished; } /** * Extract docIds and vectors from binary doc values. 
- * - * @param values Binary doc values - * @param vectorTransfer Utility to make transfer - * @return KNNCodecUtil.Pair representing doc ids and corresponding vectors - * @throws IOException thrown when unable to get binary of vectors - */ - public static KNNCodecUtil.Pair getPair(final BinaryDocValues values, final VectorTransfer vectorTransfer) throws IOException { + * + * @param values Binary doc values + * @param vectorTransfer Utility to make transfer + * @return KNNCodecUtil.Pair representing doc ids and corresponding vectors + * @throws IOException thrown when unable to get binary of vectors + */ + public static KNNCodecUtil.VectorBatch getVectorBatch( + final BinaryDocValues values, + final VectorTransfer vectorTransfer, + boolean iterative + ) throws IOException { List docIdList = new ArrayList<>(); SerializationMode serializationMode = SerializationMode.COLLECTION_OF_FLOATS; vectorTransfer.init(getTotalLiveDocsCount(values)); - for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) { + int doc = values.nextDoc(); + for (; doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) { BytesRef bytesref = values.binaryValue(); serializationMode = vectorTransfer.getSerializationMode(bytesref); vectorTransfer.transfer(bytesref); docIdList.add(doc); + // Semi-hacky way to check if the streaming limit has been reached + if (iterative && vectorTransfer.getVectorAddress() != 0) { + break; + } } vectorTransfer.close(); - return new KNNCodecUtil.Pair( + return new KNNCodecUtil.VectorBatch( docIdList.stream().mapToInt(Integer::intValue).toArray(), vectorTransfer.getVectorAddress(), vectorTransfer.getDimension(), - serializationMode + serializationMode, + doc == DocIdSetIterator.NO_MORE_DOCS ); } @@ -115,7 +126,7 @@ public static String buildEngineFileSuffix(String fieldName, String extension) { return String.format("_%s%s", fieldName, extension); } - private static long getTotalLiveDocsCount(final BinaryDocValues 
binaryDocValues) { + public static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues) { long totalLiveDocs; if (binaryDocValues instanceof KNN80BinaryDocValues) { totalLiveDocs = ((KNN80BinaryDocValues) binaryDocValues).getTotalLiveDocs(); diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 21de907657..8fc39a5acf 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -49,6 +49,24 @@ class FaissService { }); } + public static native long initIndex(long numDocs, int dim, Map parameters); + + public static native long initBinaryIndex(long numDocs, int dim, Map parameters); + + public static native void insertToIndex(int[] ids, long vectorsAddress, int dim, long indexAddress, Map parameters); + + public static native void insertToBinaryIndex( + int[] ids, + long vectorsAddress, + int dim, + long indexAddress, + Map parameters + ); + + public static native void writeIndex(long indexAddress, String indexPath, Map parameters); + + public static native void writeBinaryIndex(long indexAddress, String indexPath, Map parameters); + /** * Create an index for the native library The memory occupied by the vectorsAddress will be freed up during the * function call. So Java layer doesn't need to free up the memory. 
This is not an ideal behavior because Java layer diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index cefd0af53e..e221bf8d77 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -23,6 +23,55 @@ * Service to distribute requests to the proper engine jni service */ public class JNIService { + private static final String FAISS_BINARY_INDEX_PREFIX = "B"; + + public static long initIndexFromScratch(long size, int dim, Map parameters, KNNEngine knnEngine) { + if (KNNEngine.FAISS == knnEngine) { + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { + return FaissService.initBinaryIndex(size, dim, parameters); + } else { + return FaissService.initIndex(size, dim, parameters); + } + } + + throw new IllegalArgumentException( + String.format("initIndexFromScratch not supported for provided engine : %s", knnEngine.getName()) + ); + } + + public static void insertToIndex( + int[] docs, + long vectorAddress, + int dimension, + Map parameters, + long indexAddress, + KNNEngine knnEngine + ) { + if (KNNEngine.FAISS == knnEngine) { + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { + FaissService.insertToBinaryIndex(docs, vectorAddress, dimension, indexAddress, parameters); + } else { + FaissService.insertToIndex(docs, vectorAddress, dimension, indexAddress, parameters); + } + return; + } + + throw new IllegalArgumentException(String.format("insertToIndex not supported for provided engine : %s", knnEngine.getName())); + } + + public static void writeIndex(String indexPath, long indexAddress, KNNEngine knnEngine, Map parameters) { + if (KNNEngine.FAISS == knnEngine) { + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { + FaissService.writeBinaryIndex(indexAddress, indexPath, parameters); + } else { + FaissService.writeIndex(indexAddress, indexPath, parameters); + } + return; + } + + throw new 
IllegalArgumentException(String.format("writeIndex not supported for provided engine : %s", knnEngine.getName())); + } + /** * Create an index for the native library. The memory occupied by the vectorsAddress will be freed up during the * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer diff --git a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java index 2ff0f08e51..50ae8d139e 100644 --- a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java @@ -39,7 +39,7 @@ public void testGetPair_whenCalled_thenReturn() { when(vectorTransfer.getDimension()).thenReturn(dimension); // Run - KNNCodecUtil.Pair pair = KNNCodecUtil.getPair(binaryDocValues, vectorTransfer); + KNNCodecUtil.VectorBatch batch = KNNCodecUtil.getVectorBatch(binaryDocValues, vectorTransfer, false); // Verify verify(vectorTransfer).init(liveDocCount); @@ -47,9 +47,9 @@ public void testGetPair_whenCalled_thenReturn() { verify(vectorTransfer).transfer(any(BytesRef.class)); verify(vectorTransfer).close(); - assertTrue(Arrays.equals(docId, pair.docs)); - assertEquals(vectorAddress, pair.getVectorAddress()); - assertEquals(dimension, pair.getDimension()); - assertEquals(SerializationMode.COLLECTIONS_OF_BYTES, pair.serializationMode); + assertTrue(Arrays.equals(docId, batch.docs)); + assertEquals(vectorAddress, batch.getVectorAddress()); + assertEquals(dimension, batch.getDimension()); + assertEquals(SerializationMode.COLLECTIONS_OF_BYTES, batch.serializationMode); } }