From 648f3422cc6e1d19a4026a7d51fdab3e16ae5c82 Mon Sep 17 00:00:00 2001 From: Junqiu Lei Date: Tue, 25 Jun 2024 13:16:26 -0700 Subject: [PATCH] Add binary format support for Faiss IVF Signed-off-by: Junqiu Lei --- jni/include/faiss_wrapper.h | 19 +++ .../org_opensearch_knn_jni_FaissService.h | 16 +++ jni/src/faiss_wrapper.cpp | 133 ++++++++++++++++++ .../org_opensearch_knn_jni_FaissService.cpp | 28 ++++ .../KNN80Codec/KNN80DocValuesConsumer.java | 38 +++-- .../index/mapper/KNNVectorFieldMapper.java | 12 +- .../knn/index/mapper/ModelFieldMapper.java | 8 +- .../org/opensearch/knn/index/util/Faiss.java | 2 +- .../org/opensearch/knn/indices/ModelDao.java | 1 + .../opensearch/knn/indices/ModelMetadata.java | 49 +++++-- .../org/opensearch/knn/jni/FaissService.java | 29 ++++ .../org/opensearch/knn/jni/JNIService.java | 15 +- .../plugin/rest/RestTrainModelHandler.java | 22 +-- .../transport/TrainingModelRequest.java | 11 +- .../TrainingModelTransportAction.java | 3 +- .../opensearch/knn/training/TrainingJob.java | 16 ++- 16 files changed, 360 insertions(+), 42 deletions(-) diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 2b9bc2c767..d25bf8f7c4 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -29,6 +29,12 @@ namespace knn_jni { jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, jobject parametersJ); + // Create an index with ids and vectors. Instead of creating a new index, this function creates the index + // based off of the template index passed in. The index is serialized to indexPathJ. + void CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, + jobject parametersJ); + // Load an index from indexPathJ into memory. 
// // Return a pointer to the loaded index @@ -80,6 +86,12 @@ namespace knn_jni { jobjectArray QueryBinaryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, jbyteArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ); + // Execute a query against the binary index located in memory at indexPointerJ along with Filters + // + // Return an array of KNNQueryResults + jobjectArray QueryBinaryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, + jbyteArray queryVectorJ, jint kJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ); + // Free the index located in memory at indexPointerJ void Free(jlong indexPointer, jboolean isBinaryIndexJ); @@ -96,6 +108,13 @@ namespace knn_jni { jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension, jlong trainVectorsPointerJ); + // Create an empty binary index defined by the values in the Java map, parametersJ. Train the index with + // the vector of floats located at trainVectorsPointerJ. + // + // Return the serialized representation + jbyteArray TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension, + jlong trainVectorsPointerJ); + /* * Perform a range search with filter against the index located in memory at indexPointerJ. 
* diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 7cc071ff38..025fb12e88 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -43,6 +43,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryInde JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: createBinaryIndexFromTemplate + * Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V + */ + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate + (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject); + /* * Class: org_opensearch_knn_jni_FaissService * Method: loadIndex @@ -139,6 +147,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_initLibrary JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex (JNIEnv *, jclass, jobject, jint, jlong); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: trainBinaryIndex + * Signature: (Ljava/util/Map;IJ)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinaryIndex + (JNIEnv *, jclass, jobject, jint, jlong); + /* * Class: org_opensearch_knn_jni_FaissService * Method: transferVectors diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 9abb2357f6..92393245ee 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -70,6 +70,9 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, // Train an index with data provided void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x); +// Train a binary index with data provided +void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x); + // Converts 
the int FilterIds to Faiss ids type array. void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds); @@ -223,6 +226,76 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * faiss::write_index(&idMap, indexPathCpp.c_str()); } +void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, + jbyteArray templateIndexJ, jobject parametersJ) { + if (idsJ == nullptr) { + throw std::runtime_error("IDs cannot be null"); + } + + if (vectorsAddressJ <= 0) { + throw std::runtime_error("VectorsAddress cannot be less than 0"); + } + + if(dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); + } + + if (indexPathJ == nullptr) { + throw std::runtime_error("Index path cannot be null"); + } + + if (templateIndexJ == nullptr) { + throw std::runtime_error("Template index cannot be null"); + } + + // Set thread count if it is passed in as a parameter. 
Setting this variable will only impact the current thread + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + omp_set_num_threads(threadCount); + } + jniUtil->DeleteLocalRef(env, parametersJ); + + // Read data set + // Read vectors from memory address + auto *inputVectors = reinterpret_cast<std::vector<uint8_t>*>(vectorsAddressJ); + int dim = (int)dimJ; + if (dim % 8 != 0) { + throw std::runtime_error("Dimensions should be multiply of 8"); + } + int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8)); + int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); + if (numIds != numVectors) { + throw std::runtime_error("Number of IDs does not match number of vectors"); + } + + // Get vector of bytes from jbytearray + int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ); + jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr); + + faiss::VectorIOReader vectorIoReader; + for (int i = 0; i < indexBytesCount; i++) { + vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]); + } + jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT); + + // Create faiss index + std::unique_ptr<faiss::IndexBinary> indexWriter; + indexWriter.reset(faiss::read_index_binary(&vectorIoReader, 0)); + + auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); + faiss::IndexBinaryIDMap idMap = faiss::IndexBinaryIDMap(indexWriter.get()); + idMap.add_with_ids(numVectors, reinterpret_cast<const uint8_t*>(inputVectors->data()), idVector.data()); + // Releasing the vectorsAddressJ memory as that is not required once we have created the index. 
+ // This is not the ideal approach, please refer this gh issue for long term solution: + // https://github.com/opensearch-project/k-NN/issues/1600 + delete inputVectors; + // Write the index to disk + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + faiss::write_index_binary(&idMap, indexPathCpp.c_str()); +} + jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) { if (indexPathJ == nullptr) { throw std::runtime_error("Index path cannot be null"); @@ -624,6 +697,57 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti return ret; } +jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, + jint dimensionJ, jlong trainVectorsPointerJ) { + // First, we need to build the index + if (parametersJ == nullptr) { + throw std::runtime_error("Parameters cannot be null"); + } + + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + + jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); + std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); + faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp); + + // Create faiss index + jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); + std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); + + std::unique_ptr<faiss::IndexBinary> indexWriter; + indexWriter.reset(faiss::index_binary_factory((int) dimensionJ, indexDescriptionCpp.c_str())); + + // Set thread count if it is passed in as a parameter. 
Setting this variable will only impact the current thread + if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + omp_set_num_threads(threadCount); + } + + // Train index if needed + auto *trainingVectorsPointerCpp = reinterpret_cast<std::vector<float>*>(trainVectorsPointerJ); + int numVectors = trainingVectorsPointerCpp->size()/(int) dimensionJ; + if(!indexWriter->is_trained) { + InternalTrainBinaryIndex(indexWriter.get(), numVectors, trainingVectorsPointerCpp->data()); + } + jniUtil->DeleteLocalRef(env, parametersJ); + + // Now that indexWriter is trained, we just load the bytes into an array and return + faiss::VectorIOWriter vectorIoWriter; + faiss::write_index_binary(indexWriter.get(), &vectorIoWriter); + + // Wrap in smart pointer + std::unique_ptr<jbyte[]> jbytesBuffer; + jbytesBuffer.reset(new jbyte[vectorIoWriter.data.size()]); + int c = 0; + for (auto b : vectorIoWriter.data) { + jbytesBuffer[c++] = (jbyte) b; + } + + jbyteArray ret = jniUtil->NewByteArray(env, vectorIoWriter.data.size()); + jniUtil->SetByteArrayRegion(env, ret, 0, vectorIoWriter.data.size(), jbytesBuffer.get()); + return ret; +} + faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType) { if (spaceType == knn_jni::L2) { return faiss::METRIC_L2; @@ -682,6 +806,15 @@ void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) { } } +void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x) { + if (auto * indexIvf = dynamic_cast<faiss::IndexBinaryIVF*>(index)) { + indexIvf->make_direct_map(); + } + if (!index->is_trained) { + index->train(n, reinterpret_cast<const uint8_t*>(x)); + } +} + std::unique_ptr buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector* bitmap) { int *parentIdsArray = jniUtil->GetIntArrayElements(env, parentIdsJ, nullptr); int parentIdsLength = jniUtil->GetJavaIntArrayLength(env, 
parentIdsJ); diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 6e447b0347..2394e2951f 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -90,6 +90,21 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT } } +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate(JNIEnv * env, jclass cls, + jintArray idsJ, + jlong vectorsAddressJ, + jint dimJ, + jstring indexPathJ, + jbyteArray templateIndexJ, + jobject parametersJ) +{ + try { + knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex(JNIEnv * env, jclass cls, jstring indexPathJ) { try { @@ -220,6 +235,19 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex return nullptr; } +JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinaryIndex(JNIEnv * env, jclass cls, + jobject parametersJ, + jint dimensionJ, + jlong trainVectorsPointerJ) +{ + try { + return knn_jni::faiss_wrapper::TrainBinaryIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ); + } catch (...) 
{ + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return nullptr; +} + JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors(JNIEnv * env, jclass cls, jlong vectorsPointerJ, jobjectArray vectorsJ) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index a4256c2555..11e5b52be4 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -5,7 +5,6 @@ package org.opensearch.knn.index.codec.KNN80Codec; -import com.google.common.collect.ImmutableMap; import lombok.NonNull; import lombok.extern.log4j.Log4j2; import org.apache.lucene.store.ChecksumIndexInput; @@ -112,10 +111,22 @@ private KNNEngine getKNNEngine(@NonNull FieldInfo field) { } private VectorTransfer getVectorTransfer(FieldInfo field) { - if (VectorDataType.BINARY.getValue().equalsIgnoreCase(field.attributes().get(KNNConstants.VECTOR_DATA_TYPE_FIELD))) { + boolean isBinary = false; + + // Check if the field has a model ID and retrieve the model's vector data type + if (field.attributes().containsKey(MODEL_ID)) { + Model model = ModelCache.getInstance().get(field.attributes().get(MODEL_ID)); + isBinary = model.getModelMetadata().getVectorDataType() == VectorDataType.BINARY; + } else if (VectorDataType.BINARY.getValue().equalsIgnoreCase(field.attributes().get(KNNConstants.VECTOR_DATA_TYPE_FIELD))) { + isBinary = true; + } + + // Return the appropriate VectorTransfer instance based on the vector data type + if (isBinary) { return new VectorTransferByte(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); + } else { + return new VectorTransferFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); } - return new VectorTransferFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); } public void 
addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh) @@ -154,7 +165,7 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, if (model.getModelBlob() == null) { throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); } - indexCreator = () -> createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, indexPath); + indexCreator = () -> createKNNIndexFromTemplate(model, pair, knnEngine, indexPath); } else { indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath); } @@ -188,18 +199,25 @@ private void recordRefreshStats() { KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment(); } - private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) { - Map parameters = ImmutableMap.of( - KNNConstants.INDEX_THREAD_QTY, - KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) - ); + private void createKNNIndexFromTemplate(Model model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) { + Map parameters = new HashMap<>(); + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + + // Update index description of Faiss for binary data type + if (KNNEngine.FAISS == knnEngine && VectorDataType.BINARY.equals(model.getModelMetadata().getVectorDataType())) { + parameters.put( + KNNConstants.INDEX_DESCRIPTION_PARAMETER, + FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + model.getModelMetadata().getMethodComponentContext().getName().toUpperCase() + ); + } + AccessController.doPrivileged((PrivilegedAction) () -> { JNIService.createIndexFromTemplate( pair.docs, pair.getVectorAddress(), pair.getDimension(), indexPath, - model, + model.getModelBlob(), parameters, knnEngine ); diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java 
b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 69cdbfcd72..cbd9d9e2cb 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -562,7 +562,8 @@ protected void parseCreateField(ParseContext context) throws IOException { context, fieldType().getDimension(), fieldType().getSpaceType(), - getMethodComponentContext(fieldType().getKnnMethodContext()) + getMethodComponentContext(fieldType().getKnnMethodContext()), + fieldType().getVectorDataType() ); } @@ -605,8 +606,13 @@ protected List getFieldsForByteVector(final byte[] array, final FieldType return fields; } - protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext) - throws IOException { + protected void parseCreateField( + ParseContext context, + int dimension, + SpaceType spaceType, + MethodComponentContext methodComponentContext, + VectorDataType vectorDataType + ) throws IOException { validateIfKNNPluginEnabled(); validateIfCircuitBreakerIsNotTriggered(); diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index 5548712790..adaaef28e6 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -62,6 +62,12 @@ protected void parseCreateField(ParseContext context) throws IOException { ); } - parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getSpaceType(), modelMetadata.getMethodComponentContext()); + parseCreateField( + context, + modelMetadata.getDimension(), + modelMetadata.getSpaceType(), + modelMetadata.getMethodComponentContext(), + modelMetadata.getVectorDataType() + ); } } diff --git a/src/main/java/org/opensearch/knn/index/util/Faiss.java 
b/src/main/java/org/opensearch/knn/index/util/Faiss.java index 711c206f50..4e39c1af18 100644 --- a/src/main/java/org/opensearch/knn/index/util/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/util/Faiss.java @@ -305,7 +305,7 @@ public class Faiss extends NativeLibrary { return ((4L * centroids * dimension) / BYTES_PER_KILOBYTES) + 1; }) .build() - ).addSpaces(SpaceType.UNDEFINED, SpaceType.L2, SpaceType.INNER_PRODUCT).build() + ).addSpaces(SpaceType.UNDEFINED, SpaceType.L2, SpaceType.INNER_PRODUCT, SpaceType.HAMMING_BIT).build() ); final static Faiss INSTANCE = new Faiss( diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index 0bc6c5edbb..37edcd3aef 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -292,6 +292,7 @@ private void putInternal(Model model, ActionListener listener, Do put(KNNConstants.MODEL_DESCRIPTION, modelMetadata.getDescription()); put(KNNConstants.MODEL_ERROR, modelMetadata.getError()); put(KNNConstants.MODEL_NODE_ASSIGNMENT, modelMetadata.getNodeAssignment()); + put(KNNConstants.VECTOR_DATA_TYPE_FIELD, modelMetadata.getVectorDataType()); MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext(); if (!methodComponentContext.getName().isEmpty()) { diff --git a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java index f3a5506cdb..90df1cb559 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java +++ b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java @@ -26,6 +26,7 @@ import org.opensearch.knn.index.IndexUtil; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; @@ -48,6 +49,7 @@ public class 
ModelMetadata implements Writeable, ToXContentObject { final private String timestamp; final private String description; final private String trainingNodeAssignment; + final private VectorDataType vectorDataType; private MethodComponentContext methodComponentContext; private String error; @@ -81,6 +83,12 @@ public ModelMetadata(StreamInput in) throws IOException { } else { this.methodComponentContext = MethodComponentContext.EMPTY; } + + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), KNNConstants.METHOD_PARAMETER)) { + this.vectorDataType = VectorDataType.get(in.readOptionalString()); + } else { + this.vectorDataType = VectorDataType.FLOAT; + } } /** @@ -105,7 +113,8 @@ public ModelMetadata( String description, String error, String trainingNodeAssignment, - MethodComponentContext methodComponentContext + MethodComponentContext methodComponentContext, + VectorDataType vectorDataType ) { this.knnEngine = Objects.requireNonNull(knnEngine, "knnEngine must not be null"); this.spaceType = Objects.requireNonNull(spaceType, "spaceType must not be null"); @@ -128,6 +137,7 @@ public ModelMetadata( this.error = Objects.requireNonNull(error, "error must not be null"); this.trainingNodeAssignment = Objects.requireNonNull(trainingNodeAssignment, "node assignment must not be null"); this.methodComponentContext = Objects.requireNonNull(methodComponentContext, "method context must not be null"); + this.vectorDataType = Objects.requireNonNull(vectorDataType, "vector data type must not be null"); } /** @@ -211,6 +221,10 @@ public MethodComponentContext getMethodComponentContext() { return methodComponentContext; } + public VectorDataType getVectorDataType() { + return vectorDataType; + } + /** * setter for model's state * @@ -241,7 +255,8 @@ public String toString() { description, error, trainingNodeAssignment, - methodComponentContext.toClusterStateString() + methodComponentContext.toClusterStateString(), + vectorDataType.getValue() ); } @@ -259,6 +274,7 @@ 
public boolean equals(Object obj) { equalsBuilder.append(getTimestamp(), other.getTimestamp()); equalsBuilder.append(getDescription(), other.getDescription()); equalsBuilder.append(getError(), other.getError()); + equalsBuilder.append(getVectorDataType(), other.getVectorDataType()); return equalsBuilder.isEquals(); } @@ -273,6 +289,7 @@ public int hashCode() { .append(getDescription()) .append(getError()) .append(getMethodComponentContext()) + .append(getVectorDataType()) .toHashCode(); } @@ -288,7 +305,7 @@ public static ModelMetadata fromString(String modelMetadataString) { // Training node assignment was added as a field in Version 2.12.0 // Because models can be created on older versions and the cluster can be upgraded after, // we need to accept model metadata arrays both with and without the training node assignment. - if (modelMetadataArray.length == 7) { + if (modelMetadataArray.length == 8) { log.debug( "Model metadata array does not contain training node assignment or method component context. Assuming empty string node assignment and empty method component context." ); @@ -299,6 +316,7 @@ public static ModelMetadata fromString(String modelMetadataString) { String timestamp = modelMetadataArray[4]; String description = modelMetadataArray[5]; String error = modelMetadataArray[6]; + VectorDataType vectorDataType = VectorDataType.get(modelMetadataArray[7]); return new ModelMetadata( knnEngine, spaceType, @@ -308,9 +326,10 @@ public static ModelMetadata fromString(String modelMetadataString) { description, error, "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + vectorDataType ); - } else if (modelMetadataArray.length == 8) { + } else if (modelMetadataArray.length == 9) { log.debug("Model metadata contains training node assignment. 
Assuming empty method component context."); KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]); SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]); @@ -320,6 +339,7 @@ public static ModelMetadata fromString(String modelMetadataString) { String description = modelMetadataArray[5]; String error = modelMetadataArray[6]; String trainingNodeAssignment = modelMetadataArray[7]; + VectorDataType vectorDataType = VectorDataType.get(modelMetadataArray[8]); return new ModelMetadata( knnEngine, spaceType, @@ -329,9 +349,10 @@ public static ModelMetadata fromString(String modelMetadataString) { description, error, trainingNodeAssignment, - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + vectorDataType ); - } else if (modelMetadataArray.length == 9) { + } else if (modelMetadataArray.length == 10) { log.debug("Model metadata contains training node assignment and method context"); KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]); SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]); @@ -342,6 +363,7 @@ public static ModelMetadata fromString(String modelMetadataString) { String error = modelMetadataArray[6]; String trainingNodeAssignment = modelMetadataArray[7]; MethodComponentContext methodComponentContext = MethodComponentContext.fromClusterStateString(modelMetadataArray[8]); + VectorDataType vectorDataType = VectorDataType.get(modelMetadataArray[9]); return new ModelMetadata( knnEngine, spaceType, @@ -351,7 +373,8 @@ public static ModelMetadata fromString(String modelMetadataString) { description, error, trainingNodeAssignment, - methodComponentContext + methodComponentContext, + vectorDataType ); } else { throw new IllegalArgumentException( @@ -387,6 +410,7 @@ public static ModelMetadata getMetadataFromSourceMap(final Map m Object error = modelSourceMap.get(KNNConstants.MODEL_ERROR); Object trainingNodeAssignment = modelSourceMap.get(KNNConstants.MODEL_NODE_ASSIGNMENT); Object methodComponentContext = 
modelSourceMap.get(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT); + Object vectorDataType = modelSourceMap.get(KNNConstants.VECTOR_DATA_TYPE_FIELD); if (trainingNodeAssignment == null) { trainingNodeAssignment = ""; @@ -416,7 +440,8 @@ public static ModelMetadata getMetadataFromSourceMap(final Map m objectToString(description), objectToString(error), objectToString(trainingNodeAssignment), - (MethodComponentContext) methodComponentContext + (MethodComponentContext) methodComponentContext, + VectorDataType.get(objectToString(vectorDataType)) ); return modelMetadata; } @@ -436,6 +461,9 @@ public void writeTo(StreamOutput out) throws IOException { if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), IndexUtil.MODEL_METHOD_COMPONENT_CONTEXT_KEY)) { getMethodComponentContext().writeTo(out); } + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), KNNConstants.METHOD_PARAMETER)) { + out.writeOptionalString(vectorDataType.getValue()); + } } @Override @@ -456,6 +484,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws getMethodComponentContext().toXContent(builder, params); builder.endObject(); } + if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(KNNConstants.METHOD_PARAMETER)) { + builder.field(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); + } return builder; } } diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 21de907657..1f23f6fcdd 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -96,6 +96,25 @@ public static native void createIndexFromTemplate( Map parameters ); + /** + * Create a binary index for the native library with a provided template index + * + * @param ids array of ids mapping to the data passed in + * @param vectorsAddress address of native memory where vectors are stored + * @param dim dimension of the vector to 
be indexed + * @param indexPath path to save index file to + * @param templateIndex empty template index + * @param parameters additional build time parameters + */ + public static native void createBinaryIndexFromTemplate( + int[] ids, + long vectorsAddress, + int dim, + String indexPath, + byte[] templateIndex, + Map parameters + ); + /** * Load an index into memory * @@ -249,6 +268,16 @@ public static native KNNQueryResult[] queryBinaryIndexWithFilter( */ public static native byte[] trainIndex(Map indexParameters, int dimension, long trainVectorsPointer); + /** + * Train an empty binary index + * + * @param indexParameters parameters used to build index + * @param dimension dimension for the index + * @param trainVectorsPointer pointer to where training vectors are stored in native memory + * @return bytes array of trained template index + */ + public static native byte[] trainBinaryIndex(Map indexParameters, int dimension, long trainVectorsPointer); + /** *

* The function is deprecated. Use {@link JNICommons#storeVectorData(long, float[][], long)} diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 02483f1a73..9958736bc8 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -85,8 +85,13 @@ public static void createIndexFromTemplate( KNNEngine knnEngine ) { if (KNNEngine.FAISS == knnEngine) { - FaissService.createIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); - return; + if (faissUtil.isBinaryIndex(parameters)) { + FaissService.createBinaryIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); + return; + } else { + FaissService.createIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); + return; + } } throw new IllegalArgumentException( @@ -310,7 +315,11 @@ public static void freeSharedIndexState(long shareIndexStateAddr, KNNEngine knnE */ public static byte[] trainIndex(Map indexParameters, int dimension, long trainVectorsPointer, KNNEngine knnEngine) { if (KNNEngine.FAISS == knnEngine) { - return FaissService.trainIndex(indexParameters, dimension, trainVectorsPointer); + if (faissUtil.isBinaryIndex(indexParameters)) { + return FaissService.trainBinaryIndex(indexParameters, dimension, trainVectorsPointer); + } else { + return FaissService.trainIndex(indexParameters, dimension, trainVectorsPointer); + } } throw new IllegalArgumentException(String.format("TrainIndex not supported for provided engine : %s", knnEngine.getName())); diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java index fb8ccc4cec..eec2540af9 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java @@ 
-17,6 +17,7 @@ import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.indices.ModelUtil; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.knn.plugin.transport.TrainingJobRouterAction; @@ -30,16 +31,7 @@ import java.util.Locale; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.knn.common.KNNConstants.DIMENSION; -import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; -import static org.opensearch.knn.common.KNNConstants.MAX_VECTOR_COUNT_PARAMETER; -import static org.opensearch.knn.common.KNNConstants.MODELS; -import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION; -import static org.opensearch.knn.common.KNNConstants.MODEL_ID; -import static org.opensearch.knn.common.KNNConstants.PREFERENCE_PARAMETER; -import static org.opensearch.knn.common.KNNConstants.SEARCH_SIZE_PARAMETER; -import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER; -import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.*; /** * Rest Handler for model training api endpoint. 
@@ -83,6 +75,7 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr String trainingIndex = (String) DEFAULT_NOT_SET_OBJECT_VALUE; String trainingField = (String) DEFAULT_NOT_SET_OBJECT_VALUE; String description = (String) DEFAULT_NOT_SET_OBJECT_VALUE; + VectorDataType vectorDataType = (VectorDataType) DEFAULT_NOT_SET_OBJECT_VALUE; int dimension = DEFAULT_NOT_SET_INT_VALUE; int maximumVectorCount = DEFAULT_NOT_SET_INT_VALUE; @@ -110,6 +103,8 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr } else if (MODEL_DESCRIPTION.equals(fieldName) && ensureNotSet(fieldName, description)) { description = parser.textOrNull(); ModelUtil.blockCommasInModelDescription(description); + } else if (VECTOR_DATA_TYPE_FIELD.equals(fieldName) && ensureNotSet(fieldName, vectorDataType)) { + vectorDataType = VectorDataType.get(parser.text()); } else { throw new IllegalArgumentException("Unable to parse token. \"" + fieldName + "\" is not a valid " + "parameter."); } @@ -126,6 +121,10 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr description = ""; } + if (vectorDataType == DEFAULT_NOT_SET_OBJECT_VALUE) { + vectorDataType = VectorDataType.FLOAT; + } + TrainingModelRequest trainingModelRequest = new TrainingModelRequest( modelId, knnMethodContext, @@ -133,7 +132,8 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr trainingIndex, trainingField, preferredNodeId, - description + description, + vectorDataType ); if (maximumVectorCount != DEFAULT_NOT_SET_INT_VALUE) { diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java index 5f3913ac53..ebb286659d 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java @@ -21,6 +21,7 @@ import 
org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.IndexUtil; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.training.VectorSpaceInfo; @@ -41,6 +42,7 @@ public class TrainingModelRequest extends ActionRequest { private final String trainingField; private final String preferredNodeId; private final String description; + private final VectorDataType vectorDataType; private int maximumVectorCount; private int searchSize; @@ -65,7 +67,8 @@ public TrainingModelRequest( String trainingIndex, String trainingField, String preferredNodeId, - String description + String description, + VectorDataType vectorDataType ) { super(); this.modelId = modelId; @@ -75,6 +78,7 @@ public TrainingModelRequest( this.trainingField = trainingField; this.preferredNodeId = preferredNodeId; this.description = description; + this.vectorDataType = vectorDataType; // Set these as defaults initially. If call wants to override them, they can use the setters. this.maximumVectorCount = Integer.MAX_VALUE; // By default, get all vectors in the index @@ -103,6 +107,7 @@ public TrainingModelRequest(StreamInput in) throws IOException { this.maximumVectorCount = in.readInt(); this.searchSize = in.readInt(); this.trainingDataSizeInKB = in.readInt(); + this.vectorDataType = VectorDataType.get(in.readOptionalString()); } /** @@ -213,6 +218,10 @@ public int getSearchSize() { return searchSize; } + public VectorDataType getVectorDataType() { + return vectorDataType; + } + /** * Setter for search size. 
* diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java index 33b420e2c0..58ac41b313 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java @@ -68,7 +68,8 @@ protected void doExecute(Task task, TrainingModelRequest request, ActionListener modelAnonymousEntryContext, request.getDimension(), request.getDescription(), - clusterService.localNode().getEphemeralId() + clusterService.localNode().getEphemeralId(), + request.getVectorDataType() ); KNNCounter.TRAINING_REQUESTS.increment(); diff --git a/src/main/java/org/opensearch/knn/training/TrainingJob.java b/src/main/java/org/opensearch/knn/training/TrainingJob.java index aa2786c0a2..493d5e68ea 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJob.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJob.java @@ -17,6 +17,7 @@ import org.opensearch.common.UUIDs; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.memory.NativeMemoryAllocation; @@ -32,6 +33,8 @@ import java.util.Map; import java.util.Objects; +import static org.opensearch.knn.index.util.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; + /** * Encapsulates all information required to generate and train a model. */ @@ -66,7 +69,8 @@ public TrainingJob( NativeMemoryEntryContext.AnonymousEntryContext modelAnonymousEntryContext, int dimension, String description, - String nodeAssignment + String nodeAssignment, + VectorDataType vectorDataType ) { // Generate random base64 string if one is not provided this.modelId = StringUtils.isNotBlank(modelId) ? 
modelId : UUIDs.randomBase64UUID(); @@ -84,7 +88,8 @@ public TrainingJob( description, "", nodeAssignment, - knnMethodContext.getMethodComponentContext() + knnMethodContext.getMethodComponentContext(), + vectorDataType ), null, this.modelId @@ -182,6 +187,13 @@ public void run() { KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) ); + if (VectorDataType.BINARY.equals(model.getModelMetadata().getVectorDataType())) { + trainParameters.put( + KNNConstants.INDEX_DESCRIPTION_PARAMETER, + FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + trainParameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() + ); + } + byte[] modelBlob = JNIService.trainIndex( trainParameters, model.getModelMetadata().getDimension(),