diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index dc5725c13e..92a8c39ca8 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -27,7 +27,13 @@ namespace knn_jni { // based off of the template index passed in. The index is serialized to indexPathJ. void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, - jobject parametersJ, IndexService* indexService); + jobject parametersJ); + + // Create an index with ids and vectors. Instead of creating a new index, this function creates the index + // based off of the template index passed in. The index is serialized to indexPathJ. + void CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, + jobject parametersJ); // Load an index from indexPathJ into memory. // diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index d9395fb600..e9f2a6c484 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -163,7 +163,7 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, - jbyteArray templateIndexJ, jobject parametersJ, IndexService* indexService) { + jbyteArray templateIndexJ, jobject parametersJ) { if (idsJ == nullptr) { throw std::runtime_error("IDs cannot be null"); } @@ -192,6 +192,7 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * } jniUtil->DeleteLocalRef(env, parametersJ); + // Read data set // Read vectors from memory address auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); int dim = (int)dimJ; @@ -211,28 +212,87 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * } jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT); + // Create faiss index + std::unique_ptr indexWriter; + indexWriter.reset(faiss::read_index(&vectorIoReader, 0)); + auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); + faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get()); + idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data()); + // Releasing the vectorsAddressJ memory as that is not required once we have created the index. + // This is not the ideal approach, please refer this gh issue for long term solution: + // https://github.com/opensearch-project/k-NN/issues/1600 + delete inputVectors; + // Write the index to disk + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + faiss::write_index(&idMap, indexPathCpp.c_str()); +} +void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, + jbyteArray templateIndexJ, jobject parametersJ) { + if (idsJ == nullptr) { + throw std::runtime_error("IDs cannot be null"); + } - std::vector indexBytes(indexBytesJ, indexBytesJ + indexBytesCount); - jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT); + if (vectorsAddressJ <= 0) { + throw std::runtime_error("VectorsAddress cannot be less than 0"); + } - // Convert IDs to vector - auto ids = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); + if(dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); + } - // Index path - std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + if (indexPathJ == nullptr) { + throw std::runtime_error("Index path cannot be null"); + } - // Template index path - std::string templateIndexPathCpp(reinterpret_cast(indexBytes.data()), indexBytes.size()); + if (templateIndexJ == nullptr) { + throw std::runtime_error("Template index cannot be null"); + } - // Create index from template - indexService->createIndexFromTemplate(jniUtil, env, vectorIoReader, idVector, numVectors, inputVectors, indexPathCpp); + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + omp_set_num_threads(threadCount); + } + jniUtil->DeleteLocalRef(env, parametersJ); + // Read data set + // Read vectors from memory address + auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); + int dim = (int)dimJ; + int numVectors = (int) (inputVectors->size() / (uint64_t) dim); + int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); + if (numIds != numVectors) { + throw std::runtime_error("Number of IDs does not match number of vectors"); + } + + // Get vector of bytes from jbytearray + int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ); + jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr); + + faiss::VectorIOReader vectorIoReader; + for (int i = 0; i < indexBytesCount; i++) { + vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]); + } + jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT); + + // Create faiss index + std::unique_ptr indexWriter; + indexWriter.reset(faiss::read_index_binary(&vectorIoReader, 0)); + + auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); + faiss::IndexBinaryIDMap idMap = faiss::IndexBinaryIDMap(indexWriter.get()); + idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data()); // Releasing the vectorsAddressJ memory as that is not required once we have created the index. // This is not the ideal approach, please refer this gh issue for long term solution: // https://github.com/opensearch-project/k-NN/issues/1600 delete inputVectors; + // Write the index to disk + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + faiss::write_index_binary(&idMap, indexPathCpp.c_str()); } jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) { diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index c38100507f..fcc87030ad 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -84,10 +84,7 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT jobject parametersJ) { try { - std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); - knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); - - knn_jni::faiss_wrapper::CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ, &indexService); + knn_jni::faiss_wrapper::CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } @@ -102,9 +99,7 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryInde jobject parametersJ) { try { - std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); - knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods)); - knn_jni::faiss_wrapper::CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ, &binaryIndexService); + knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index c8c36da4d8..030e10f752 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -148,15 +148,11 @@ TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) { std::unordered_map parametersMap; parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType; - std::unique_ptr faissMethods(new FaissMethods()); - knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); - knn_jni::faiss_wrapper::CreateIndexFromTemplate( &mockJNIUtil, jniEnv, reinterpret_cast(&ids), (jlong)vectors, dim, (jstring)&indexPath, reinterpret_cast(&(vectorIoWriter.data)), - (jobject) ¶metersMap, - &indexService + (jobject) ¶metersMap ); // Make sure index can be loaded