From 629757b5d07d18182536080a93c74f14375fb2d7 Mon Sep 17 00:00:00 2001 From: Junqiu Lei Date: Mon, 8 Jul 2024 17:05:11 -0700 Subject: [PATCH] Use data type check when creating index from template and fix jni errors Signed-off-by: Junqiu Lei --- .../org_opensearch_knn_jni_FaissService.h | 8 -------- jni/src/faiss_wrapper.cpp | 12 +++++------- .../org_opensearch_knn_jni_FaissService.cpp | 14 +------------- .../KNN80Codec/KNN80DocValuesConsumer.java | 18 +++++++++++++++--- .../opensearch/knn/indices/ModelMetadata.java | 6 +++--- .../org/opensearch/knn/jni/JNIService.java | 2 ++ 6 files changed, 26 insertions(+), 34 deletions(-) diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index b77fd71d4b..fa20371c49 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -116,14 +116,6 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBinaryIndexWithFilter (JNIEnv *, jclass, jlong, jbyteArray, jint, jobject, jlongArray, jint, jintArray); -/* - * Class: org_opensearch_knn_jni_FaissService - * Method: queryBIndexWithFilter - * Signature: (J[BI[JI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult; - */ -JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBinaryIndexWithFilter - (JNIEnv *, jclass, jlong, jbyteArray, jint, jlongArray, jint, jintArray); - /* * Class: org_opensearch_knn_jni_FaissService * Method: free diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 37cce029a8..92393245ee 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -155,8 +155,6 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN } // end parameters to pass - std::cout << "Index description in CreateIndex: " << indexDescriptionCpp << std::endl; - // Create index 
indexService->createIndex(jniUtil, env, metric, indexDescriptionCpp, dim, numIds, threadCount, vectorsAddress, ids, indexPathCpp, subParametersCpp); } @@ -263,7 +261,10 @@ void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInter // Read vectors from memory address auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); int dim = (int)dimJ; - int numVectors = (int) (inputVectors->size() / (uint64_t) dim); + if (dim % 8 != 0) { + throw std::runtime_error("Dimensions should be multiply of 8"); + } + int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8)); int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); if (numIds != numVectors) { throw std::runtime_error("Number of IDs does not match number of vectors"); @@ -285,7 +286,7 @@ void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInter auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); faiss::IndexBinaryIDMap idMap = faiss::IndexBinaryIDMap(indexWriter.get()); - idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data()); + idMap.add_with_ids(numVectors, reinterpret_cast(inputVectors->data()), idVector.data()); // Releasing the vectorsAddressJ memory as that is not required once we have created the index. 
// This is not the ideal approach, please refer this gh issue for long term solution: // https://github.com/opensearch-project/k-NN/issues/1600 @@ -647,7 +648,6 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); - std::cout << "Index description in TrainIndex: " << indexDescriptionCpp << std::endl; std::unique_ptr indexWriter; indexWriter.reset(faiss::index_factory((int) dimensionJ, indexDescriptionCpp.c_str(), metric)); @@ -714,7 +714,6 @@ jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex(knn_jni::JNIUtilInterface * jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); - std::cout << "Index description in TrainIndex: " << indexDescriptionCpp << std::endl; std::unique_ptr indexWriter; indexWriter.reset(faiss::index_binary_factory((int) dimensionJ, indexDescriptionCpp.c_str())); @@ -809,7 +808,6 @@ void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) { void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x) { if (auto * indexIvf = dynamic_cast(index)) { - std::cout << "Index is IVFBinary" << std::endl; indexIvf->make_direct_map(); } if (!index->is_trained) { diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index f36a9d2f3e..2394e2951f 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -194,19 +194,7 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBin } -JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBinaryIndexWithFilter - (JNIEnv * env, 
jclass cls, jlong indexPointerJ, jbyteArray queryVectorJ, jint kJ, jlongArray filteredIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { - - try { - return knn_jni::faiss_wrapper::QueryBinaryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIdsJ, filterIdsTypeJ, parentIdsJ); - } catch (...) { - jniUtil.CatchCppExceptionAndThrowJava(env); - } - return nullptr; - -} - -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * env, jclass cls, jlong indexPointerJ) +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * env, jclass cls, jlong indexPointerJ, jboolean isBinaryIndexJ) { try { return knn_jni::faiss_wrapper::Free(indexPointerJ, isBinaryIndexJ); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index de05ff20f5..d48b2df6a6 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -112,10 +112,22 @@ private KNNEngine getKNNEngine(@NonNull FieldInfo field) { } private VectorTransfer getVectorTransfer(FieldInfo field) { - if (VectorDataType.BINARY.getValue().equalsIgnoreCase(field.attributes().get(KNNConstants.VECTOR_DATA_TYPE_FIELD))) { + boolean isBinary = false; + + // Check if the field has a model ID and retrieve the model's vector data type + if (field.attributes().containsKey(MODEL_ID)) { + Model model = ModelCache.getInstance().get(field.attributes().get(MODEL_ID)); + isBinary = model.getModelMetadata().getVectorDataType() == VectorDataType.BINARY; + } else if (VectorDataType.BINARY.getValue().equalsIgnoreCase(field.attributes().get(KNNConstants.VECTOR_DATA_TYPE_FIELD))) { + isBinary = true; + } + + // Return the appropriate VectorTransfer instance based on the vector data type + if (isBinary) { return new 
VectorTransferByte(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); + } else { + return new VectorTransferFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); } - return new VectorTransferFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); } public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh) @@ -194,7 +206,7 @@ private void createKNNIndexFromTemplate(Model model, KNNCodecUtil.Pair pair, KNN KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); // Update index description of Faiss for binary data type - if (KNNEngine.FAISS == knnEngine && SpaceType.HAMMING_BIT.equals(model.getModelMetadata().getSpaceType())) { + if (KNNEngine.FAISS == knnEngine && VectorDataType.BINARY.equals(model.getModelMetadata().getVectorDataType())) { parameters.put( KNNConstants.INDEX_DESCRIPTION_PARAMETER, FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + model.getModelMetadata().getMethodComponentContext().getName().toUpperCase() diff --git a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java index 1498987fd6..90df1cb559 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java +++ b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java @@ -84,7 +84,7 @@ public ModelMetadata(StreamInput in) throws IOException { this.methodComponentContext = MethodComponentContext.EMPTY; } - if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), IndexUtil.MODEL_DATA_TYPE)) { + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), KNNConstants.METHOD_PARAMETER)) { this.vectorDataType = VectorDataType.get(in.readOptionalString()); } else { this.vectorDataType = VectorDataType.FLOAT; @@ -461,7 +461,7 @@ public void writeTo(StreamOutput out) throws IOException { if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), IndexUtil.MODEL_METHOD_COMPONENT_CONTEXT_KEY)) { 
getMethodComponentContext().writeTo(out); } - if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), IndexUtil.MODEL_DATA_TYPE)) { + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), KNNConstants.METHOD_PARAMETER)) { out.writeOptionalString(vectorDataType.getValue()); } } @@ -484,7 +484,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws getMethodComponentContext().toXContent(builder, params); builder.endObject(); } - if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(IndexUtil.MODEL_DATA_TYPE)) { + if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(KNNConstants.METHOD_PARAMETER)) { builder.field(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); } return builder; diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 4a46e4d055..9958736bc8 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -87,8 +87,10 @@ public static void createIndexFromTemplate( if (KNNEngine.FAISS == knnEngine) { if (faissUtil.isBinaryIndex(parameters)) { FaissService.createBinaryIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); + return; } else { FaissService.createIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); + return; } }