diff --git a/jni/include/faiss_index_service.h b/jni/include/faiss_index_service.h index 59f15fda9c..dcfc3b7873 100644 --- a/jni/include/faiss_index_service.h +++ b/jni/include/faiss_index_service.h @@ -20,6 +20,10 @@ #include "faiss_methods.h" #include +namespace faiss { + struct VectorIOReader; +} + namespace knn_jni { namespace faiss_wrapper { @@ -61,6 +65,16 @@ class IndexService { std::vector ids, std::string indexPath, std::unordered_map parameters); + + virtual void createIndexFromTemplate( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + faiss::VectorIOReader vectorIoReader, + std::vector idVector, + int numVectors, + std::vector *inputVectors, + std::string& indexPathCpp); + virtual ~IndexService() = default; protected: std::unique_ptr faissMethods; @@ -103,6 +117,16 @@ class BinaryIndexService : public IndexService { std::string indexPath, std::unordered_map parameters ) override; + + virtual void createIndexFromTemplate( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + faiss::VectorIOReader vectorIoReader, + std::vector idVector, + int numVectors, + std::vector *inputVectors, + std::string& indexPathCpp) override; + virtual ~BinaryIndexService() = default; }; diff --git a/jni/include/faiss_methods.h b/jni/include/faiss_methods.h index 38d8d756a7..a09a97e82c 100644 --- a/jni/include/faiss_methods.h +++ b/jni/include/faiss_methods.h @@ -32,6 +32,8 @@ class FaissMethods { virtual faiss::IndexIDMapTemplate* indexBinaryIdMap(faiss::IndexBinary* index); virtual void writeIndex(const faiss::Index* idx, const char* fname); virtual void writeIndexBinary(const faiss::IndexBinary* idx, const char* fname); + virtual faiss::Index* readIndex(const char* indexPath); + virtual faiss::IndexBinary* readIndexBinary(const char* indexPath); virtual ~FaissMethods() = default; }; diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 89986a231d..dc5725c13e 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -27,7 +27,7 @@ namespace knn_jni { // based off of the template index passed in. The index is serialized to indexPathJ. void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, - jobject parametersJ); + jobject parametersJ, IndexService* indexService); // Load an index from indexPathJ into memory. // @@ -102,6 +102,13 @@ namespace knn_jni { jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension, jlong trainVectorsPointerJ); + // Create an empty binary index defined by the values in the Java map, parametersJ. Train the index with + // the vector of floats located at trainVectorsPointerJ. + // + // Return the serialized representation + jbyteArray TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension, + jlong trainVectorsPointerJ); + /* * Perform a range search with filter against the index located in memory at indexPointerJ. * diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index d416054346..7ebf5e2284 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -35,6 +35,7 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndex (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject); + /* * Class: org_opensearch_knn_jni_FaissService * Method: createIndexFromTemplate @@ -43,6 +44,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryInde JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: createBinaryIndexFromTemplate + * Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V + */ + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate + (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject); + /* * Class: org_opensearch_knn_jni_FaissService * Method: loadIndex @@ -147,6 +156,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_initLibrary JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex (JNIEnv *, jclass, jobject, jint, jlong); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: trainBinaryIndex + * Signature: (Ljava/util/Map;IJ)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinaryIndex + (JNIEnv *, jclass, jobject, jint, jlong); + /* * Class: org_opensearch_knn_jni_FaissService * Method: transferVectors diff --git a/jni/src/faiss_index_service.cpp b/jni/src/faiss_index_service.cpp index 8c5ba36af2..8e919dfe7d 100644 --- a/jni/src/faiss_index_service.cpp +++ b/jni/src/faiss_index_service.cpp @@ -23,6 +23,7 @@ #include #include #include +#include namespace knn_jni { namespace faiss_wrapper { @@ -106,6 +107,31 @@ void IndexService::createIndex( faissMethods->writeIndex(idMap.get(), indexPath.c_str()); } +void IndexService::createIndexFromTemplate( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + faiss::VectorIOReader vectorIoReader, + std::vector idVector, + int numVectors, + std::vector *inputVectors, + std::string& indexPathCpp) { + // Read vectors from memory address + // Create faiss index + std::unique_ptr indexWriter; + indexWriter.reset(faiss::read_index(&vectorIoReader, 0)); + + faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get()); + + idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data()); + + // Releasing the vectorsAddressJ memory as that is not required once we have created the index. + // This is not the ideal approach, please refer this gh issue for long term solution: + // https://github.com/opensearch-project/k-NN/issues/1600 + delete inputVectors; + + faiss::write_index(&idMap, indexPathCpp.c_str()); +} + BinaryIndexService::BinaryIndexService(std::unique_ptr faissMethods) : IndexService(std::move(faissMethods)) {} void BinaryIndexService::createIndex( @@ -160,5 +186,35 @@ void BinaryIndexService::createIndex( faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str()); } +void BinaryIndexService::createIndexFromTemplate( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + faiss::VectorIOReader vectorIoReader, + std::vector idVector, + int numVectors, + std::vector *inputVectors, + std::string& indexPathCpp) { + // Read vectors from memory address + // Create faiss index + std::unique_ptr indexWriter; + indexWriter.reset(dynamic_cast(faiss::read_index(&vectorIoReader, 0))); + + // faiss::IndexBinaryIDMap idMap = faiss::IndexBinaryIDMap(indexWriter.get()); + + // idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data()); + std::unique_ptr idMap(faissMethods->indexBinaryIdMap(indexWriter.get())); + + auto* vectorData = reinterpret_cast(inputVectors->data()); + idMap->add_with_ids(numVectors, vectorData, idVector.data()); + + // Releasing the vectorsAddressJ memory as that is not required once we have created the index. + // This is not the ideal approach, please refer this gh issue for long term solution: + // https://github.com/opensearch-project/k-NN/issues/1600 + delete inputVectors; + + // Write the index to disk + faissMethods->writeIndexBinary(idMap.get(), indexPathCpp.c_str()); +} + } // namespace faiss_wrapper -} // namesapce knn_jni +} // namesapce knn_jni \ No newline at end of file diff --git a/jni/src/faiss_methods.cpp b/jni/src/faiss_methods.cpp index 05c8f459ae..c2653896dd 100644 --- a/jni/src/faiss_methods.cpp +++ b/jni/src/faiss_methods.cpp @@ -32,9 +32,18 @@ faiss::IndexIDMapTemplate* FaissMethods::indexBinaryIdMap(fa void FaissMethods::writeIndex(const faiss::Index* idx, const char* fname) { faiss::write_index(idx, fname); } + void FaissMethods::writeIndexBinary(const faiss::IndexBinary* idx, const char* fname) { faiss::write_index_binary(idx, fname); } +faiss::Index* FaissMethods::readIndex(const char* indexPath) { + return faiss::read_index(indexPath); +} + +faiss::IndexBinary* FaissMethods::readIndexBinary(const char* indexPath) { + return reinterpret_cast(faiss::read_index_binary(indexPath)); +} + } // namespace faiss_wrapper } // namesapce knn_jni diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 3eda03b419..d9395fb600 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -70,6 +70,9 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, // Train an index with data provided void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x); +// Train a binary index with data provided +void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x); + // Converts the int FilterIds to Faiss ids type array. void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds); @@ -152,13 +155,15 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN } // end parameters to pass + std::cout << "Index description in CreateIndex: " << indexDescriptionCpp << std::endl; + // Create index indexService->createIndex(jniUtil, env, metric, indexDescriptionCpp, dim, numIds, threadCount, vectorsAddress, ids, indexPathCpp, subParametersCpp); } void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, - jbyteArray templateIndexJ, jobject parametersJ) { + jbyteArray templateIndexJ, jobject parametersJ, IndexService* indexService) { if (idsJ == nullptr) { throw std::runtime_error("IDs cannot be null"); } @@ -187,7 +192,6 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * } jniUtil->DeleteLocalRef(env, parametersJ); - // Read data set // Read vectors from memory address auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); int dim = (int)dimJ; @@ -207,20 +211,28 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * } jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT); - // Create faiss index - std::unique_ptr indexWriter; - indexWriter.reset(faiss::read_index(&vectorIoReader, 0)); - auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); - faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get()); - idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data()); + + + std::vector indexBytes(indexBytesJ, indexBytesJ + indexBytesCount); + jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT); + + // Convert IDs to vector + auto ids = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); + + // Index path + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + + // Template index path + std::string templateIndexPathCpp(reinterpret_cast(indexBytes.data()), indexBytes.size()); + + // Create index from template + indexService->createIndexFromTemplate(jniUtil, env, vectorIoReader, idVector, numVectors, inputVectors, indexPathCpp); + // Releasing the vectorsAddressJ memory as that is not required once we have created the index. // This is not the ideal approach, please refer this gh issue for long term solution: // https://github.com/opensearch-project/k-NN/issues/1600 delete inputVectors; - // Write the index to disk - std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); - faiss::write_index(&idMap, indexPathCpp.c_str()); } jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) { @@ -568,6 +580,7 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); + std::cout << "Index description in TrainIndex: " << indexDescriptionCpp << std::endl; std::unique_ptr indexWriter; indexWriter.reset(faiss::index_factory((int) dimensionJ, indexDescriptionCpp.c_str(), metric)); @@ -617,6 +630,58 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti return ret; } +jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, + jint dimensionJ, jlong trainVectorsPointerJ) { + // First, we need to build the index + if (parametersJ == nullptr) { + throw std::runtime_error("Parameters cannot be null"); + } + + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + + jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); + std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); + faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp); + + // Create faiss index + jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); + std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); + + std::cout << "Index description in TrainIndex: " << indexDescriptionCpp << std::endl; + std::unique_ptr indexWriter; + indexWriter.reset(faiss::index_binary_factory((int) dimensionJ, indexDescriptionCpp.c_str())); + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + omp_set_num_threads(threadCount); + } + + // Train index if needed + auto *trainingVectorsPointerCpp = reinterpret_cast*>(trainVectorsPointerJ); + int numVectors = trainingVectorsPointerCpp->size()/(int) dimensionJ; + if(!indexWriter->is_trained) { + InternalTrainBinaryIndex(indexWriter.get(), numVectors, trainingVectorsPointerCpp->data()); + } + jniUtil->DeleteLocalRef(env, parametersJ); + + // Now that indexWriter is trained, we just load the bytes into an array and return + faiss::VectorIOWriter vectorIoWriter; + faiss::write_index_binary(indexWriter.get(), &vectorIoWriter); + + // Wrap in smart pointer + std::unique_ptr jbytesBuffer; + jbytesBuffer.reset(new jbyte[vectorIoWriter.data.size()]); + int c = 0; + for (auto b : vectorIoWriter.data) { + jbytesBuffer[c++] = (jbyte) b; + } + + jbyteArray ret = jniUtil->NewByteArray(env, vectorIoWriter.data.size()); + jniUtil->SetByteArrayRegion(env, ret, 0, vectorIoWriter.data.size(), jbytesBuffer.get()); + return ret; +} + faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType) { if (spaceType == knn_jni::L2) { return faiss::METRIC_L2; @@ -675,6 +740,16 @@ void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) { } } +void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x) { + if (auto * indexIvf = dynamic_cast(index)) { + std::cout << "Index is IVFBinary" << std::endl; + indexIvf->make_direct_map(); + } + if (!index->is_trained) { + index->train(n, reinterpret_cast(x)); + } +} + std::unique_ptr buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector* bitmap) { int *parentIdsArray = jniUtil->GetIntArrayElements(env, parentIdsJ, nullptr); int parentIdsLength = jniUtil->GetJavaIntArrayLength(env, parentIdsJ); diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index b868946fe3..c38100507f 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -84,7 +84,27 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT jobject parametersJ) { try { - knn_jni::faiss_wrapper::CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ); + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); + + knn_jni::faiss_wrapper::CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ, &indexService); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate(JNIEnv * env, jclass cls, + jintArray idsJ, + jlong vectorsAddressJ, + jint dimJ, + jstring indexPathJ, + jbyteArray templateIndexJ, + jobject parametersJ) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ, &binaryIndexService); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } @@ -232,6 +252,19 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex return nullptr; } +JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinaryIndex(JNIEnv * env, jclass cls, + jobject parametersJ, + jint dimensionJ, + jlong trainVectorsPointerJ) +{ + try { + return knn_jni::faiss_wrapper::TrainBinaryIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return nullptr; +} + JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors(JNIEnv * env, jclass cls, jlong vectorsPointerJ, jobjectArray vectorsJ) diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 030e10f752..c8c36da4d8 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -148,11 +148,15 @@ TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) { std::unordered_map parametersMap; parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType; + std::unique_ptr faissMethods(new FaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::CreateIndexFromTemplate( &mockJNIUtil, jniEnv, reinterpret_cast(&ids), (jlong)vectors, dim, (jstring)&indexPath, reinterpret_cast(&(vectorIoWriter.data)), - (jobject) ¶metersMap + (jobject) ¶metersMap, + &indexService ); // Make sure index can be loaded diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index c005b2279c..6861751d68 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -154,7 +154,7 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, if (model.getModelBlob() == null) { throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); } - indexCreator = () -> createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, indexPath); + indexCreator = () -> createKNNIndexFromTemplate(model, pair, knnEngine, indexPath); } else { indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath); } @@ -188,18 +188,26 @@ private void recordRefreshStats() { KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment(); } - private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) { - Map parameters = ImmutableMap.of( - KNNConstants.INDEX_THREAD_QTY, - KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) - ); + private void createKNNIndexFromTemplate(Model model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) { + Map parameters = new HashMap<>(); + parameters.put(KNNConstants.INDEX_THREAD_QTY, + KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + + // Update index description of Faiss for binary data type + if (KNNEngine.FAISS == knnEngine && SpaceType.HAMMING_BIT.equals(model.getModelMetadata().getSpaceType())) { + parameters.put( + KNNConstants.INDEX_DESCRIPTION_PARAMETER, + FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + model.getModelMetadata().getMethodComponentContext().getName().toUpperCase() + ); + } + AccessController.doPrivileged((PrivilegedAction) () -> { JNIService.createIndexFromTemplate( pair.docs, pair.getVectorAddress(), pair.getDimension(), indexPath, - model, + model.getModelBlob(), parameters, knnEngine ); diff --git a/src/main/java/org/opensearch/knn/index/util/Faiss.java b/src/main/java/org/opensearch/knn/index/util/Faiss.java index 711c206f50..4e39c1af18 100644 --- a/src/main/java/org/opensearch/knn/index/util/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/util/Faiss.java @@ -305,7 +305,7 @@ public class Faiss extends NativeLibrary { return ((4L * centroids * dimension) / BYTES_PER_KILOBYTES) + 1; }) .build() - ).addSpaces(SpaceType.UNDEFINED, SpaceType.L2, SpaceType.INNER_PRODUCT).build() + ).addSpaces(SpaceType.UNDEFINED, SpaceType.L2, SpaceType.INNER_PRODUCT, SpaceType.HAMMING_BIT).build() ); final static Faiss INSTANCE = new Faiss( diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index dbf13028db..5ad982af56 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -96,6 +96,25 @@ public static native void createIndexFromTemplate( Map parameters ); + /** + * Create a binary index for the native library with a provided template index + * + * @param ids array of ids mapping to the data passed in + * @param vectorsAddress address of native memory where vectors are stored + * @param dim dimension of the vector to be indexed + * @param indexPath path to save index file to + * @param templateIndex empty template index + * @param parameters additional build time parameters + */ + public static native void createBinaryIndexFromTemplate( + int[] ids, + long vectorsAddress, + int dim, + String indexPath, + byte[] templateIndex, + Map parameters + ); + /** * Load an index into memory * @@ -249,6 +268,16 @@ public static native KNNQueryResult[] queryBinaryIndexWithFilter( */ public static native byte[] trainIndex(Map indexParameters, int dimension, long trainVectorsPointer); + /** + * Train an empty binary index + * + * @param indexParameters parameters used to build index + * @param dimension dimension for the index + * @param trainVectorsPointer pointer to where training vectors are stored in native memory + * @return bytes array of trained template index + */ + public static native byte[] trainBinaryIndex(Map indexParameters, int dimension, long trainVectorsPointer); + /** *

* The function is deprecated. Use {@link JNICommons#storeVectorData(long, float[][], long)} diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index fa3a29e3af..d5dce7f114 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -86,8 +86,11 @@ public static void createIndexFromTemplate( KNNEngine knnEngine ) { if (KNNEngine.FAISS == knnEngine) { - FaissService.createIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); - return; + if (faissUtil.isBinaryIndex(parameters)) { + FaissService.createBinaryIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); + } else { + FaissService.createIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); + } } throw new IllegalArgumentException( @@ -301,7 +304,11 @@ public static void freeSharedIndexState(long shareIndexStateAddr, KNNEngine knnE */ public static byte[] trainIndex(Map indexParameters, int dimension, long trainVectorsPointer, KNNEngine knnEngine) { if (KNNEngine.FAISS == knnEngine) { - return FaissService.trainIndex(indexParameters, dimension, trainVectorsPointer); + if (faissUtil.isBinaryIndex(indexParameters)) { + return FaissService.trainBinaryIndex(indexParameters, dimension, trainVectorsPointer); + } else { + return FaissService.trainIndex(indexParameters, dimension, trainVectorsPointer); + } } throw new IllegalArgumentException(String.format("TrainIndex not supported for provided engine : %s", knnEngine.getName())); diff --git a/src/main/java/org/opensearch/knn/training/TrainingJob.java b/src/main/java/org/opensearch/knn/training/TrainingJob.java index aa2786c0a2..0486bec6fc 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJob.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJob.java @@ -17,6 +17,7 @@ import org.opensearch.common.UUIDs; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.memory.NativeMemoryAllocation; @@ -32,6 +33,8 @@ import java.util.Map; import java.util.Objects; +import static org.opensearch.knn.index.util.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; + /** * Encapsulates all information required to generate and train a model. */ @@ -182,6 +185,13 @@ public void run() { KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) ); + if (modelMetadata.getSpaceType().equals(SpaceType.HAMMING_BIT)) { + trainParameters.put( + KNNConstants.INDEX_DESCRIPTION_PARAMETER, + FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + trainParameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() + ); + } + byte[] modelBlob = JNIService.trainIndex( trainParameters, model.getModelMetadata().getDimension(),