From 735ccda9f679f632e43ca65e18472227fff0f45f Mon Sep 17 00:00:00 2001 From: Junqiu Lei Date: Tue, 25 Jun 2024 13:16:26 -0700 Subject: [PATCH] Add binary format support with IVF method in Faiss Engine Signed-off-by: Junqiu Lei --- CHANGELOG.md | 1 + jni/include/faiss_wrapper.h | 19 +++ .../org_opensearch_knn_jni_FaissService.h | 16 +++ jni/src/faiss_wrapper.cpp | 133 ++++++++++++++++++ .../org_opensearch_knn_jni_FaissService.cpp | 28 ++++ .../KNN80Codec/KNN80DocValuesConsumer.java | 38 +++-- .../index/mapper/KNNVectorFieldMapper.java | 12 +- .../knn/index/mapper/ModelFieldMapper.java | 8 +- .../knn/index/query/KNNQueryBuilder.java | 1 + .../opensearch/knn/index/query/KNNWeight.java | 11 +- .../org/opensearch/knn/index/util/Faiss.java | 2 +- .../knn/index/util/ModelInfoExtractor.java | 32 +++++ .../org/opensearch/knn/indices/ModelDao.java | 1 + .../opensearch/knn/indices/ModelMetadata.java | 49 +++++-- .../org/opensearch/knn/jni/FaissService.java | 29 ++++ .../org/opensearch/knn/jni/JNIService.java | 15 +- .../plugin/rest/RestTrainModelHandler.java | 22 +-- .../transport/TrainingModelRequest.java | 12 +- .../TrainingModelTransportAction.java | 3 +- .../opensearch/knn/training/TrainingJob.java | 16 ++- .../opensearch/knn/KNNSingleNodeTestCase.java | 3 +- .../index/KNNCreateIndexFromModelTests.java | 3 +- .../KNN80DocValuesConsumerTests.java | 3 +- .../knn/index/codec/KNNCodecTestCase.java | 9 +- .../mapper/KNNVectorFieldMapperTests.java | 18 ++- .../knn/index/query/KNNQueryBuilderTests.java | 3 + .../knn/index/query/KNNWeightTests.java | 11 +- .../knn/indices/ModelCacheTests.java | 38 +++-- .../opensearch/knn/indices/ModelDaoTests.java | 43 ++++-- .../knn/indices/ModelMetadataTests.java | 122 +++++++++++----- .../opensearch/knn/indices/ModelTests.java | 110 ++++++++++++--- .../transport/GetModelResponseTests.java | 8 +- ...oveModelFromCacheTransportActionTests.java | 4 +- ...TrainingJobRouterTransportActionTests.java | 4 +- .../transport/TrainingModelRequestTests.java | 48 +++++-- .../TrainingModelTransportActionTests.java | 4 +- ...ateModelGraveyardTransportActionTests.java | 4 +- .../UpdateModelMetadataRequestTests.java | 10 +- ...dateModelMetadataTransportActionTests.java | 4 +- .../knn/training/TrainingJobTests.java | 28 ++-- 40 files changed, 737 insertions(+), 188 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/util/ModelInfoExtractor.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ec910ccd1..f2d993f35d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Adds dynamic query parameter ef_search [#1783](https://github.com/opensearch-project/k-NN/pull/1783) * Adds dynamic query parameter ef_search in radial search faiss engine [#1790](https://github.com/opensearch-project/k-NN/pull/1790) * Add binary format support with HNSW method in Faiss Engine [#1781](https://github.com/opensearch-project/k-NN/pull/1781) +* Add binary format support with IVF method in Faiss Engine [#1784](https://github.com/opensearch-project/k-NN/pull/1784) ### Enhancements ### Bug Fixes ### Infrastructure diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 2b9bc2c767..d25bf8f7c4 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -29,6 +29,12 @@ namespace knn_jni { jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, jobject parametersJ); + // Create an index with ids and vectors. Instead of creating a new index, this function creates the index + // based off of the template index passed in. The index is serialized to indexPathJ. + void CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, + jobject parametersJ); + // Load an index from indexPathJ into memory. // // Return a pointer to the loaded index @@ -80,6 +86,12 @@ namespace knn_jni { jobjectArray QueryBinaryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, jbyteArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ); + // Execute a query against the binary index located in memory at indexPointerJ along with Filters + // + // Return an array of KNNQueryResults + jobjectArray QueryBinaryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, + jbyteArray queryVectorJ, jint kJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ); + // Free the index located in memory at indexPointerJ void Free(jlong indexPointer, jboolean isBinaryIndexJ); @@ -96,6 +108,13 @@ namespace knn_jni { jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension, jlong trainVectorsPointerJ); + // Create an empty binary index defined by the values in the Java map, parametersJ. Train the index with + // the vector of floats located at trainVectorsPointerJ. + // + // Return the serialized representation + jbyteArray TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension, + jlong trainVectorsPointerJ); + /* * Perform a range search with filter against the index located in memory at indexPointerJ. * diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 7cc071ff38..025fb12e88 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -43,6 +43,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryInde JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: createBinaryIndexFromTemplate + * Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V + */ + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate + (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject); + /* * Class: org_opensearch_knn_jni_FaissService * Method: loadIndex @@ -139,6 +147,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_initLibrary JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex (JNIEnv *, jclass, jobject, jint, jlong); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: trainBinaryIndex + * Signature: (Ljava/util/Map;IJ)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinaryIndex + (JNIEnv *, jclass, jobject, jint, jlong); + /* * Class: org_opensearch_knn_jni_FaissService * Method: transferVectors diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 9abb2357f6..92393245ee 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -70,6 +70,9 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, // Train an index with data provided void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x); +// Train a binary index with data provided +void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x); + // Converts the int FilterIds to Faiss ids type array. void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds); @@ -223,6 +226,76 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * faiss::write_index(&idMap, indexPathCpp.c_str()); } +void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, + jbyteArray templateIndexJ, jobject parametersJ) { + if (idsJ == nullptr) { + throw std::runtime_error("IDs cannot be null"); + } + + if (vectorsAddressJ <= 0) { + throw std::runtime_error("VectorsAddress cannot be less than 0"); + } + + if(dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); + } + + if (indexPathJ == nullptr) { + throw std::runtime_error("Index path cannot be null"); + } + + if (templateIndexJ == nullptr) { + throw std::runtime_error("Template index cannot be null"); + } + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + omp_set_num_threads(threadCount); + } + jniUtil->DeleteLocalRef(env, parametersJ); + + // Read data set + // Read vectors from memory address + auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); + int dim = (int)dimJ; + if (dim % 8 != 0) { + throw std::runtime_error("Dimensions should be multiply of 8"); + } + int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8)); + int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); + if (numIds != numVectors) { + throw std::runtime_error("Number of IDs does not match number of vectors"); + } + + // Get vector of bytes from jbytearray + int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ); + jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr); + + faiss::VectorIOReader vectorIoReader; + for (int i = 0; i < indexBytesCount; i++) { + vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]); + } + jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT); + + // Create faiss index + std::unique_ptr indexWriter; + indexWriter.reset(faiss::read_index_binary(&vectorIoReader, 0)); + + auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); + faiss::IndexBinaryIDMap idMap = faiss::IndexBinaryIDMap(indexWriter.get()); + idMap.add_with_ids(numVectors, reinterpret_cast(inputVectors->data()), idVector.data()); + // Releasing the vectorsAddressJ memory as that is not required once we have created the index. + // This is not the ideal approach, please refer this gh issue for long term solution: + // https://github.com/opensearch-project/k-NN/issues/1600 + delete inputVectors; + // Write the index to disk + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + faiss::write_index_binary(&idMap, indexPathCpp.c_str()); +} + jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) { if (indexPathJ == nullptr) { throw std::runtime_error("Index path cannot be null"); @@ -624,6 +697,57 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti return ret; } +jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, + jint dimensionJ, jlong trainVectorsPointerJ) { + // First, we need to build the index + if (parametersJ == nullptr) { + throw std::runtime_error("Parameters cannot be null"); + } + + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + + jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); + std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); + faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp); + + // Create faiss index + jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); + std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); + + std::unique_ptr indexWriter; + indexWriter.reset(faiss::index_binary_factory((int) dimensionJ, indexDescriptionCpp.c_str())); + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + omp_set_num_threads(threadCount); + } + + // Train index if needed + auto *trainingVectorsPointerCpp = reinterpret_cast*>(trainVectorsPointerJ); + int numVectors = trainingVectorsPointerCpp->size()/(int) dimensionJ; + if(!indexWriter->is_trained) { + InternalTrainBinaryIndex(indexWriter.get(), numVectors, trainingVectorsPointerCpp->data()); + } + jniUtil->DeleteLocalRef(env, parametersJ); + + // Now that indexWriter is trained, we just load the bytes into an array and return + faiss::VectorIOWriter vectorIoWriter; + faiss::write_index_binary(indexWriter.get(), &vectorIoWriter); + + // Wrap in smart pointer + std::unique_ptr jbytesBuffer; + jbytesBuffer.reset(new jbyte[vectorIoWriter.data.size()]); + int c = 0; + for (auto b : vectorIoWriter.data) { + jbytesBuffer[c++] = (jbyte) b; + } + + jbyteArray ret = jniUtil->NewByteArray(env, vectorIoWriter.data.size()); + jniUtil->SetByteArrayRegion(env, ret, 0, vectorIoWriter.data.size(), jbytesBuffer.get()); + return ret; +} + faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType) { if (spaceType == knn_jni::L2) { return faiss::METRIC_L2; @@ -682,6 +806,15 @@ void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) { } } +void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x) { + if (auto * indexIvf = dynamic_cast(index)) { + indexIvf->make_direct_map(); + } + if (!index->is_trained) { + index->train(n, reinterpret_cast(x)); + } +} + std::unique_ptr buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector* bitmap) { int *parentIdsArray = jniUtil->GetIntArrayElements(env, parentIdsJ, nullptr); int parentIdsLength = jniUtil->GetJavaIntArrayLength(env, parentIdsJ); diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 6e447b0347..2394e2951f 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -90,6 +90,21 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT } } +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate(JNIEnv * env, jclass cls, + jintArray idsJ, + jlong vectorsAddressJ, + jint dimJ, + jstring indexPathJ, + jbyteArray templateIndexJ, + jobject parametersJ) +{ + try { + knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex(JNIEnv * env, jclass cls, jstring indexPathJ) { try { @@ -220,6 +235,19 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex return nullptr; } +JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinaryIndex(JNIEnv * env, jclass cls, + jobject parametersJ, + jint dimensionJ, + jlong trainVectorsPointerJ) +{ + try { + return knn_jni::faiss_wrapper::TrainBinaryIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return nullptr; +} + JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors(JNIEnv * env, jclass cls, jlong vectorsPointerJ, jobjectArray vectorsJ) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 349ad2f4d4..248ef026ab 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -5,7 +5,6 @@ package org.opensearch.knn.index.codec.KNN80Codec; -import com.google.common.collect.ImmutableMap; import lombok.NonNull; import lombok.extern.log4j.Log4j2; import org.apache.lucene.store.ChecksumIndexInput; @@ -112,10 +111,22 @@ private KNNEngine getKNNEngine(@NonNull FieldInfo field) { } private VectorTransfer getVectorTransfer(FieldInfo field) { - if (VectorDataType.BINARY.getValue().equalsIgnoreCase(field.attributes().get(KNNConstants.VECTOR_DATA_TYPE_FIELD))) { + boolean isBinary = false; + + // Check if the field has a model ID and retrieve the model's vector data type + if (field.attributes().containsKey(MODEL_ID)) { + Model model = ModelCache.getInstance().get(field.attributes().get(MODEL_ID)); + isBinary = model.getModelMetadata().getVectorDataType() == VectorDataType.BINARY; + } else if (VectorDataType.BINARY.getValue().equalsIgnoreCase(field.attributes().get(KNNConstants.VECTOR_DATA_TYPE_FIELD))) { + isBinary = true; + } + + // Return the appropriate VectorTransfer instance based on the vector data type + if (isBinary) { return new VectorTransferByte(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); + } else { + return new VectorTransferFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); } - return new VectorTransferFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); } public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh) @@ -154,7 +165,7 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, if (model.getModelBlob() == null) { throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); } - indexCreator = () -> createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, indexPath); + indexCreator = () -> createKNNIndexFromTemplate(model, pair, knnEngine, indexPath); } else { indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath); } @@ -188,18 +199,25 @@ private void recordRefreshStats() { KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment(); } - private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) { - Map parameters = ImmutableMap.of( - KNNConstants.INDEX_THREAD_QTY, - KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) - ); + private void createKNNIndexFromTemplate(Model model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) { + Map parameters = new HashMap<>(); + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + + // Update index description of Faiss for binary data type + if (KNNEngine.FAISS == knnEngine && VectorDataType.BINARY.equals(model.getModelMetadata().getVectorDataType())) { + parameters.put( + KNNConstants.INDEX_DESCRIPTION_PARAMETER, + FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + model.getModelMetadata().getMethodComponentContext().getName().toUpperCase() + ); + } + AccessController.doPrivileged((PrivilegedAction) () -> { JNIService.createIndexFromTemplate( pair.docs, pair.getVectorAddress(), pair.getDimension(), indexPath, - model, + model.getModelBlob(), parameters, knnEngine ); diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 793cb0bfc2..d78fc873a2 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -553,7 +553,8 @@ protected void parseCreateField(ParseContext context) throws IOException { context, fieldType().getDimension(), fieldType().getSpaceType(), - getMethodComponentContext(fieldType().getKnnMethodContext()) + getMethodComponentContext(fieldType().getKnnMethodContext()), + fieldType().getVectorDataType() ); } @@ -596,8 +597,13 @@ protected List getFieldsForByteVector(final byte[] array, final FieldType return fields; } - protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext) - throws IOException { + protected void parseCreateField( + ParseContext context, + int dimension, + SpaceType spaceType, + MethodComponentContext methodComponentContext, + VectorDataType vectorDataType + ) throws IOException { validateIfKNNPluginEnabled(); validateIfCircuitBreakerIsNotTriggered(); diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index 5548712790..adaaef28e6 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -62,6 +62,12 @@ protected void parseCreateField(ParseContext context) throws IOException { ); } - parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getSpaceType(), modelMetadata.getMethodComponentContext()); + parseCreateField( + context, + modelMetadata.getDimension(), + modelMetadata.getSpaceType(), + modelMetadata.getMethodComponentContext(), + modelMetadata.getVectorDataType() + ); } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 80ee5e32c7..d0de166360 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -504,6 +504,7 @@ protected Query doToQuery(QueryShardContext context) { knnEngine = modelMetadata.getKnnEngine(); spaceType = modelMetadata.getSpaceType(); methodComponentContext = modelMetadata.getMethodComponentContext(); + vectorDataType = modelMetadata.getVectorDataType(); } else if (knnMethodContext != null) { // If the dimension is set but the knnMethodContext is not then the field is using the legacy mapping diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 2c450ad8ad..0ad78d5e6e 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -43,6 +43,7 @@ import org.opensearch.knn.index.query.filtered.NestedFilteredIdsKNNIterator; import org.opensearch.knn.index.util.FieldInfoExtractor; import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.index.util.ModelInfoExtractor; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; @@ -213,6 +214,7 @@ private Map doANNSearch(final LeafReaderContext context, final B KNNEngine knnEngine; SpaceType spaceType; + String indexDescription; // Check if a modelId exists. If so, the space type and engine will need to be picked up from the model's // metadata. @@ -225,11 +227,13 @@ private Map doANNSearch(final LeafReaderContext context, final B knnEngine = modelMetadata.getKnnEngine(); spaceType = modelMetadata.getSpaceType(); + indexDescription = ModelInfoExtractor.getIndexDescription(modelMetadata); } else { String engineName = fieldInfo.attributes().getOrDefault(KNN_ENGINE, KNNEngine.NMSLIB.getName()); knnEngine = KNNEngine.getEngine(engineName); String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue()); spaceType = SpaceType.getSpace(spaceTypeName); + indexDescription = FieldInfoExtractor.getIndexDescription(fieldInfo); } /* @@ -261,12 +265,7 @@ private Map doANNSearch(final LeafReaderContext context, final B new NativeMemoryEntryContext.IndexEntryContext( indexPath.toString(), NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), - getParametersAtLoading( - spaceType, - knnEngine, - knnQuery.getIndexName(), - FieldInfoExtractor.getIndexDescription(fieldInfo) - ), + getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName(), indexDescription), knnQuery.getIndexName(), modelId ), diff --git a/src/main/java/org/opensearch/knn/index/util/Faiss.java b/src/main/java/org/opensearch/knn/index/util/Faiss.java index 711c206f50..4e39c1af18 100644 --- a/src/main/java/org/opensearch/knn/index/util/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/util/Faiss.java @@ -305,7 +305,7 @@ public class Faiss extends NativeLibrary { return ((4L * centroids * dimension) / BYTES_PER_KILOBYTES) + 1; }) .build() - ).addSpaces(SpaceType.UNDEFINED, SpaceType.L2, SpaceType.INNER_PRODUCT).build() + ).addSpaces(SpaceType.UNDEFINED, SpaceType.L2, SpaceType.INNER_PRODUCT, SpaceType.HAMMING_BIT).build() ); final static Faiss INSTANCE = new Faiss( diff --git a/src/main/java/org/opensearch/knn/index/util/ModelInfoExtractor.java b/src/main/java/org/opensearch/knn/index/util/ModelInfoExtractor.java new file mode 100644 index 0000000000..fcbcfc2b69 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/util/ModelInfoExtractor.java @@ -0,0 +1,32 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.util; + +import org.opensearch.knn.index.MethodComponentContext; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.indices.ModelMetadata; + +import static org.opensearch.knn.index.util.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; + +public class ModelInfoExtractor { + public static String getIndexDescription(ModelMetadata modelMetadata) { + MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext(); + VectorDataType vectorDataType = modelMetadata.getVectorDataType(); + String indexDescription = methodComponentContext.getName().toUpperCase(); + + if (VectorDataType.BINARY.equals(vectorDataType)) { + indexDescription = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + indexDescription; + } + + return indexDescription; + } +} diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index 0bc6c5edbb..37edcd3aef 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -292,6 +292,7 @@ private void putInternal(Model model, ActionListener listener, Do put(KNNConstants.MODEL_DESCRIPTION, modelMetadata.getDescription()); put(KNNConstants.MODEL_ERROR, modelMetadata.getError()); put(KNNConstants.MODEL_NODE_ASSIGNMENT, modelMetadata.getNodeAssignment()); + put(KNNConstants.VECTOR_DATA_TYPE_FIELD, modelMetadata.getVectorDataType()); MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext(); if (!methodComponentContext.getName().isEmpty()) { diff --git a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java index f3a5506cdb..f844842029 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java +++ b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java @@ -26,6 +26,7 @@ import org.opensearch.knn.index.IndexUtil; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; @@ -48,6 +49,7 @@ public class ModelMetadata implements Writeable, ToXContentObject { final private String timestamp; final private String description; final private String trainingNodeAssignment; + final private VectorDataType vectorDataType; private MethodComponentContext methodComponentContext; private String error; @@ -81,6 +83,12 @@ public ModelMetadata(StreamInput in) throws IOException { } else { this.methodComponentContext = MethodComponentContext.EMPTY; } + + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), KNNConstants.METHOD_PARAMETER)) { + this.vectorDataType = VectorDataType.get(in.readOptionalString()); + } else { + this.vectorDataType = VectorDataType.FLOAT; + } } /** @@ -105,7 +113,8 @@ public ModelMetadata( String description, String error, String trainingNodeAssignment, - MethodComponentContext methodComponentContext + MethodComponentContext methodComponentContext, + VectorDataType vectorDataType ) { this.knnEngine = Objects.requireNonNull(knnEngine, "knnEngine must not be null"); this.spaceType = Objects.requireNonNull(spaceType, "spaceType must not be null"); @@ -128,6 +137,7 @@ public ModelMetadata( this.error = Objects.requireNonNull(error, "error must not be null"); this.trainingNodeAssignment = Objects.requireNonNull(trainingNodeAssignment, "node assignment must not be null"); this.methodComponentContext = Objects.requireNonNull(methodComponentContext, "method context must not be null"); + this.vectorDataType = Objects.requireNonNull(vectorDataType, "vector data type must not be null"); } /** @@ -211,6 +221,10 @@ public MethodComponentContext getMethodComponentContext() { return methodComponentContext; } + public VectorDataType getVectorDataType() { + return vectorDataType; + } + /** * setter for model's state * @@ -241,7 +255,8 @@ public String toString() { description, error, trainingNodeAssignment, - methodComponentContext.toClusterStateString() + methodComponentContext.toClusterStateString(), + vectorDataType.getValue() ); } @@ -259,6 +274,7 @@ public boolean equals(Object obj) { equalsBuilder.append(getTimestamp(), other.getTimestamp()); equalsBuilder.append(getDescription(), other.getDescription()); equalsBuilder.append(getError(), other.getError()); + equalsBuilder.append(getVectorDataType(), other.getVectorDataType()); return equalsBuilder.isEquals(); } @@ -273,6 +289,7 @@ public int hashCode() { .append(getDescription()) .append(getError()) .append(getMethodComponentContext()) + .append(getVectorDataType()) .toHashCode(); } @@ -288,7 +305,7 @@ public static ModelMetadata fromString(String modelMetadataString) { // Training node assignment was added as a field in Version 2.12.0 // Because models can be created on older versions and the cluster can be upgraded after, // we need to accept model metadata arrays both with and without the training node assignment. - if (modelMetadataArray.length == 7) { + if (modelMetadataArray.length == 8) { log.debug( "Model metadata array does not contain training node assignment or method component context. Assuming empty string node assignment and empty method component context." ); @@ -299,6 +316,7 @@ public static ModelMetadata fromString(String modelMetadataString) { String timestamp = modelMetadataArray[4]; String description = modelMetadataArray[5]; String error = modelMetadataArray[6]; + VectorDataType vectorDataType = VectorDataType.get(modelMetadataArray[7]); return new ModelMetadata( knnEngine, spaceType, @@ -308,9 +326,10 @@ public static ModelMetadata fromString(String modelMetadataString) { description, error, "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + vectorDataType ); - } else if (modelMetadataArray.length == 8) { + } else if (modelMetadataArray.length == 9) { log.debug("Model metadata contains training node assignment. Assuming empty method component context."); KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]); SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]); @@ -320,6 +339,7 @@ public static ModelMetadata fromString(String modelMetadataString) { String description = modelMetadataArray[5]; String error = modelMetadataArray[6]; String trainingNodeAssignment = modelMetadataArray[7]; + VectorDataType vectorDataType = VectorDataType.get(modelMetadataArray[8]); return new ModelMetadata( knnEngine, spaceType, @@ -329,9 +349,10 @@ public static ModelMetadata fromString(String modelMetadataString) { description, error, trainingNodeAssignment, - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + vectorDataType ); - } else if (modelMetadataArray.length == 9) { + } else if (modelMetadataArray.length == 10) { log.debug("Model metadata contains training node assignment and method context"); KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]); SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]); @@ -342,6 +363,7 @@ public static ModelMetadata fromString(String modelMetadataString) { String error = modelMetadataArray[6]; String trainingNodeAssignment = modelMetadataArray[7]; MethodComponentContext methodComponentContext = MethodComponentContext.fromClusterStateString(modelMetadataArray[8]); + VectorDataType vectorDataType = VectorDataType.get(modelMetadataArray[9]); return new ModelMetadata( knnEngine, spaceType, @@ -351,7 +373,8 @@ public static ModelMetadata fromString(String modelMetadataString) { description, error, trainingNodeAssignment, - methodComponentContext + methodComponentContext, + vectorDataType ); } else { throw new IllegalArgumentException( @@ -387,6 +410,7 @@ public static ModelMetadata getMetadataFromSourceMap(final Map m Object error = modelSourceMap.get(KNNConstants.MODEL_ERROR); Object trainingNodeAssignment = modelSourceMap.get(KNNConstants.MODEL_NODE_ASSIGNMENT); Object methodComponentContext = modelSourceMap.get(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT); + Object vectorDataType = modelSourceMap.get(KNNConstants.VECTOR_DATA_TYPE_FIELD); if (trainingNodeAssignment == null) { trainingNodeAssignment = ""; @@ -416,7 +440,8 @@ public static ModelMetadata getMetadataFromSourceMap(final Map m objectToString(description), objectToString(error), objectToString(trainingNodeAssignment), - (MethodComponentContext) methodComponentContext + (MethodComponentContext) methodComponentContext, + VectorDataType.get(objectToString(vectorDataType)) ); return modelMetadata; } @@ -436,6 +461,9 @@ public void writeTo(StreamOutput out) throws IOException { if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), IndexUtil.MODEL_METHOD_COMPONENT_CONTEXT_KEY)) { getMethodComponentContext().writeTo(out); } + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), KNNConstants.METHOD_PARAMETER)) { + out.writeOptionalString(vectorDataType.getValue()); + } } @Override @@ -456,6 +484,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws getMethodComponentContext().toXContent(builder, params); builder.endObject(); } + if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(KNNConstants.METHOD_PARAMETER)) { + builder.field(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); + } return builder; } } diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 21de907657..1f23f6fcdd 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -96,6 +96,25 @@ public static native void createIndexFromTemplate( Map parameters ); + /** + * Create a binary index for the native library with a provided template index + * + * @param ids array of ids mapping to the data passed in + * @param vectorsAddress address of native memory where vectors are stored + * @param dim dimension of the vector to be indexed + * @param indexPath path to save index file to + * @param templateIndex empty template index + * @param parameters additional build time parameters + */ + public static native void createBinaryIndexFromTemplate( + int[] ids, + long vectorsAddress, + int dim, + String indexPath, + byte[] templateIndex, + Map parameters + ); + /** * Load an index into memory * @@ -249,6 +268,16 @@ public static native KNNQueryResult[] queryBinaryIndexWithFilter( */ public static native byte[] trainIndex(Map indexParameters, int dimension, long trainVectorsPointer); + /** + * Train an empty binary index + * + * @param indexParameters parameters used to build index + * @param dimension dimension for the index + * @param trainVectorsPointer pointer to where training vectors are stored in native memory + * @return bytes array of trained template index + */ + public static native byte[] trainBinaryIndex(Map indexParameters, int dimension, long trainVectorsPointer); + /** *

* The function is deprecated. Use {@link JNICommons#storeVectorData(long, float[][], long)} diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index cefd0af53e..6315f03a4d 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -83,8 +83,13 @@ public static void createIndexFromTemplate( KNNEngine knnEngine ) { if (KNNEngine.FAISS == knnEngine) { - FaissService.createIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); - return; + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { + FaissService.createBinaryIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); + return; + } else { + FaissService.createIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); + return; + } } throw new IllegalArgumentException( @@ -308,7 +313,11 @@ public static void freeSharedIndexState(long shareIndexStateAddr, KNNEngine knnE */ public static byte[] trainIndex(Map indexParameters, int dimension, long trainVectorsPointer, KNNEngine knnEngine) { if (KNNEngine.FAISS == knnEngine) { - return FaissService.trainIndex(indexParameters, dimension, trainVectorsPointer); + if (IndexUtil.isBinaryIndex(knnEngine, indexParameters)) { + return FaissService.trainBinaryIndex(indexParameters, dimension, trainVectorsPointer); + } else { + return FaissService.trainIndex(indexParameters, dimension, trainVectorsPointer); + } } throw new IllegalArgumentException(String.format("TrainIndex not supported for provided engine : %s", knnEngine.getName())); diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java index fb8ccc4cec..eec2540af9 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java @@ -17,6 +17,7 @@ import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.indices.ModelUtil; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.knn.plugin.transport.TrainingJobRouterAction; @@ -30,16 +31,7 @@ import java.util.Locale; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.knn.common.KNNConstants.DIMENSION; -import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; -import static org.opensearch.knn.common.KNNConstants.MAX_VECTOR_COUNT_PARAMETER; -import static org.opensearch.knn.common.KNNConstants.MODELS; -import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION; -import static org.opensearch.knn.common.KNNConstants.MODEL_ID; -import static org.opensearch.knn.common.KNNConstants.PREFERENCE_PARAMETER; -import static org.opensearch.knn.common.KNNConstants.SEARCH_SIZE_PARAMETER; -import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER; -import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.*; /** * Rest Handler for model training api endpoint. @@ -83,6 +75,7 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr String trainingIndex = (String) DEFAULT_NOT_SET_OBJECT_VALUE; String trainingField = (String) DEFAULT_NOT_SET_OBJECT_VALUE; String description = (String) DEFAULT_NOT_SET_OBJECT_VALUE; + VectorDataType vectorDataType = (VectorDataType) DEFAULT_NOT_SET_OBJECT_VALUE; int dimension = DEFAULT_NOT_SET_INT_VALUE; int maximumVectorCount = DEFAULT_NOT_SET_INT_VALUE; @@ -110,6 +103,8 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr } else if (MODEL_DESCRIPTION.equals(fieldName) && ensureNotSet(fieldName, description)) { description = parser.textOrNull(); ModelUtil.blockCommasInModelDescription(description); + } else if (VECTOR_DATA_TYPE_FIELD.equals(fieldName) && ensureNotSet(fieldName, vectorDataType)) { + vectorDataType = VectorDataType.get(parser.text()); } else { throw new IllegalArgumentException("Unable to parse token. \"" + fieldName + "\" is not a valid " + "parameter."); } @@ -126,6 +121,10 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr description = ""; } + if (vectorDataType == DEFAULT_NOT_SET_OBJECT_VALUE) { + vectorDataType = VectorDataType.FLOAT; + } + TrainingModelRequest trainingModelRequest = new TrainingModelRequest( modelId, knnMethodContext, @@ -133,7 +132,8 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr trainingIndex, trainingField, preferredNodeId, - description + description, + vectorDataType ); if (maximumVectorCount != DEFAULT_NOT_SET_INT_VALUE) { diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java index 5f3913ac53..b2eb8af416 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java @@ -21,6 +21,7 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.IndexUtil; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.training.VectorSpaceInfo; @@ -41,6 +42,7 @@ public class TrainingModelRequest extends ActionRequest { private final String trainingField; private final String preferredNodeId; private final String description; + private final VectorDataType vectorDataType; private int maximumVectorCount; private int searchSize; @@ -65,7 +67,8 @@ public TrainingModelRequest( String trainingIndex, String trainingField, String preferredNodeId, - String description + String description, + VectorDataType vectorDataType ) { super(); this.modelId = modelId; @@ -75,6 +78,7 @@ public TrainingModelRequest( this.trainingField = trainingField; this.preferredNodeId = preferredNodeId; this.description = description; + this.vectorDataType = vectorDataType; // Set these as defaults initially. If call wants to override them, they can use the setters. this.maximumVectorCount = Integer.MAX_VALUE; // By default, get all vectors in the index @@ -103,6 +107,7 @@ public TrainingModelRequest(StreamInput in) throws IOException { this.maximumVectorCount = in.readInt(); this.searchSize = in.readInt(); this.trainingDataSizeInKB = in.readInt(); + this.vectorDataType = VectorDataType.get(in.readOptionalString()); } /** @@ -213,6 +218,10 @@ public int getSearchSize() { return searchSize; } + public VectorDataType getVectorDataType() { + return vectorDataType; + } + /** * Setter for search size. * @@ -336,5 +345,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeInt(this.maximumVectorCount); out.writeInt(this.searchSize); out.writeInt(this.trainingDataSizeInKB); + out.writeOptionalString(this.vectorDataType.getValue()); } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java index 33b420e2c0..58ac41b313 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java @@ -68,7 +68,8 @@ protected void doExecute(Task task, TrainingModelRequest request, ActionListener modelAnonymousEntryContext, request.getDimension(), request.getDescription(), - clusterService.localNode().getEphemeralId() + clusterService.localNode().getEphemeralId(), + request.getVectorDataType() ); KNNCounter.TRAINING_REQUESTS.increment(); diff --git a/src/main/java/org/opensearch/knn/training/TrainingJob.java b/src/main/java/org/opensearch/knn/training/TrainingJob.java index aa2786c0a2..493d5e68ea 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJob.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJob.java @@ -17,6 +17,7 @@ import org.opensearch.common.UUIDs; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.memory.NativeMemoryAllocation; @@ -32,6 +33,8 @@ import java.util.Map; import java.util.Objects; +import static org.opensearch.knn.index.util.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; + /** * Encapsulates all information required to generate and train a model. */ @@ -66,7 +69,8 @@ public TrainingJob( NativeMemoryEntryContext.AnonymousEntryContext modelAnonymousEntryContext, int dimension, String description, - String nodeAssignment + String nodeAssignment, + VectorDataType vectorDataType ) { // Generate random base64 string if one is not provided this.modelId = StringUtils.isNotBlank(modelId) ? modelId : UUIDs.randomBase64UUID(); @@ -84,7 +88,8 @@ public TrainingJob( description, "", nodeAssignment, - knnMethodContext.getMethodComponentContext() + knnMethodContext.getMethodComponentContext(), + vectorDataType ), null, this.modelId @@ -182,6 +187,13 @@ public void run() { KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) ); + if (VectorDataType.BINARY.equals(model.getModelMetadata().getVectorDataType())) { + trainParameters.put( + KNNConstants.INDEX_DESCRIPTION_PARAMETER, + FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + trainParameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() + ); + } + byte[] modelBlob = JNIService.trainIndex( trainParameters, model.getModelMetadata().getDimension(), diff --git a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java index f9c0161d65..7a8c0515d0 100644 --- a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java @@ -201,7 +201,8 @@ protected void writeModelToModelSystemIndex(Model model) throws IOException, Exe .field(MODEL_STATE, modelMetadata.getState().getName()) .field(MODEL_TIMESTAMP, modelMetadata.getTimestamp().toString()) .field(MODEL_DESCRIPTION, modelMetadata.getDescription()) - .field(MODEL_ERROR, modelMetadata.getError()); + .field(MODEL_ERROR, modelMetadata.getError()) + .field(VECTOR_DATA_TYPE_FIELD, modelMetadata.getVectorDataType().getValue()); if (model.getModelBlob() != null) { builder.field(MODEL_BLOB_PARAMETER, Base64.getEncoder().encodeToString(model.getModelBlob())); diff --git a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java index 11a8bdb15d..e9b78e7ec0 100644 --- a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java @@ -63,7 +63,8 @@ public void testCreateIndexFromModel() throws IOException, InterruptedException "", "", "test-node", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); Model model = new Model(modelMetadata, modelBlob, modelId); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index 4e3231894e..68632d54d7 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -424,7 +424,8 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio "Empty description", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBytes, modelId diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index b82bc85e05..a5f5006f39 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -19,16 +19,12 @@ import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.KNNMethodContext; -import org.opensearch.knn.index.MethodComponentContext; +import org.opensearch.knn.index.*; import org.opensearch.knn.index.query.KNNQueryFactory; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.query.KNNQuery; -import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.query.KNNWeight; -import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.VectorField; import org.apache.lucene.codecs.Codec; import org.apache.lucene.document.Document; import org.apache.lucene.document.FieldType; @@ -213,7 +209,8 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); Model mockModel = new Model(modelMetadata1, modelBlob, modelId); diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index 0ba2b97bc2..8ac1713826 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -173,7 +173,8 @@ public void testBuilder_build_fromModel() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); builder.modelId.setValue(modelId); Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); @@ -674,7 +675,8 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); when(mockModelDao.getMetadata(modelId)).thenReturn(mockModelMetadata); @@ -745,7 +747,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { parseContext, TEST_DIMENSION, luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext() + luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), + VectorDataType.FLOAT ); // Document should have 2 fields: one for VectorField (binary doc values) and one for KnnVectorField @@ -789,7 +792,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { parseContext, TEST_DIMENSION, luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext() + luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), + VectorDataType.FLOAT ); // Document should have 1 field: one for KnnVectorField @@ -824,7 +828,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { parseContext, TEST_DIMENSION, luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext() + luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), + VectorDataType.BYTE ); // Document should have 2 fields: one for VectorField (binary doc values) and one for KnnByteVectorField @@ -867,7 +872,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { parseContext, TEST_DIMENSION, luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext() + luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), + VectorDataType.BYTE ); // Document should have 1 field: one for KnnByteVectorField diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 06e3700269..0ab5a34cc3 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -903,6 +903,7 @@ public void testDoToQuery_FromModel() { when(modelMetadata.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); when(modelMetadata.getState()).thenReturn(ModelState.CREATED); when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap())); + when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.FLOAT); ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); KNNQueryBuilder.initialize(modelDao); @@ -940,6 +941,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2); when(modelMetadata.getState()).thenReturn(ModelState.CREATED); when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap())); + when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.FLOAT); ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); KNNQueryBuilder.initialize(modelDao); @@ -975,6 +977,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_th when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2); when(modelMetadata.getState()).thenReturn(ModelState.CREATED); + when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap())); ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 7da80f2fe1..ff3526dde0 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -38,6 +38,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.KNNCodecVersion; @@ -62,6 +63,7 @@ import java.util.function.Function; import java.util.stream.Collectors; +import static java.util.Collections.emptyMap; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; @@ -75,12 +77,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.when; import static org.opensearch.knn.KNNRestTestCase.INDEX_NAME; -import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; -import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; -import static org.opensearch.knn.common.KNNConstants.MODEL_ID; -import static org.opensearch.knn.common.KNNConstants.PARAMETERS; -import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.*; public class KNNWeightTests extends KNNTestCase { private static final String FIELD_NAME = "target_field"; @@ -199,6 +196,8 @@ public void testQueryScoreForFaissWithModel() { when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); when(modelMetadata.getSpaceType()).thenReturn(spaceType); when(modelMetadata.getState()).thenReturn(ModelState.CREATED); + when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap())); when(modelDao.getMetadata(eq("modelId"))).thenReturn(modelMetadata); KNNWeight.initialize(modelDao); diff --git a/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java b/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java index 3a5255cd3e..7a42e8a25b 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java @@ -19,6 +19,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import java.time.ZoneOffset; @@ -45,7 +46,8 @@ public void testGet_normal() throws ExecutionException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), "hello".getBytes(), modelId @@ -82,7 +84,8 @@ public void testGet_modelDoesNotFitInCache() throws ExecutionException, Interrup "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[BYTES_PER_KILOBYTES + 1], modelId @@ -140,7 +143,8 @@ public void testGetTotalWeight() throws ExecutionException, InterruptedException "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[size1], modelId1 @@ -156,7 +160,8 @@ public void testGetTotalWeight() throws ExecutionException, InterruptedException "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[size2], modelId2 @@ -200,7 +205,8 @@ public void testRemove_normal() throws ExecutionException, InterruptedException "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[size1], modelId1 @@ -216,8 +222,8 @@ public void testRemove_normal() throws ExecutionException, InterruptedException "", "", "", - MethodComponentContext.EMPTY - + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[size2], modelId2 @@ -266,7 +272,8 @@ public void testRebuild_normal() throws ExecutionException, InterruptedException "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), "hello".getBytes(), modelId @@ -312,7 +319,8 @@ public void testRebuild_afterSettingUpdate() throws ExecutionException, Interrup "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[modelSize], modelId @@ -381,7 +389,8 @@ public void testContains() throws ExecutionException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[modelSize1], modelId1 @@ -423,7 +432,8 @@ public void testRemoveAll() throws ExecutionException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[modelSize1], modelId1 @@ -441,7 +451,8 @@ public void testRemoveAll() throws ExecutionException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[modelSize2], modelId2 @@ -487,7 +498,8 @@ public void testModelCacheEvictionDueToSize() throws ExecutionException, Interru "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[BYTES_PER_KILOBYTES * 2], modelId diff --git a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java index 75c5233321..e3619975e8 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java @@ -35,6 +35,7 @@ import org.opensearch.knn.common.exception.DeleteModelException; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.plugin.transport.DeleteModelResponse; import org.opensearch.knn.plugin.transport.GetModelResponse; @@ -139,7 +140,8 @@ public void testModelIndexHealth() throws InterruptedException, ExecutionExcepti "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, modelId @@ -159,7 +161,8 @@ public void testModelIndexHealth() throws InterruptedException, ExecutionExcepti "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, modelId @@ -187,7 +190,8 @@ public void testPut_withId() throws InterruptedException, IOException { "", "", "", - new MethodComponentContext("test", Collections.emptyMap()) + new MethodComponentContext("test", Collections.emptyMap()), + VectorDataType.FLOAT ), modelBlob, modelId @@ -248,7 +252,8 @@ public void testPut_withoutModel() throws InterruptedException, IOException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, modelId @@ -310,7 +315,8 @@ public void testPut_invalid_badState() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, "any-id" @@ -347,7 +353,8 @@ public void testUpdate() throws IOException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), null, modelId @@ -386,7 +393,8 @@ public void testUpdate() throws IOException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, modelId @@ -437,7 +445,8 @@ public void testGet() throws IOException, InterruptedException, ExecutionExcepti "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, modelId @@ -456,7 +465,8 @@ public void testGet() throws IOException, InterruptedException, ExecutionExcepti "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), null, modelId @@ -493,7 +503,8 @@ public void testGetMetadata() throws IOException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); Model model = new Model(modelMetadata, modelBlob, modelId); @@ -570,7 +581,8 @@ public void testDelete() throws IOException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, modelId @@ -604,7 +616,8 @@ public void testDelete() throws IOException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, modelId1 @@ -672,7 +685,8 @@ public void testDeleteModelInTrainingWithStepListeners() throws IOException, Exe "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, modelId @@ -714,7 +728,8 @@ public void testDeleteWithStepListeners() throws IOException, InterruptedExcepti "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, modelId diff --git a/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java b/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java index 74715671f3..cd36496f8c 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java @@ -19,6 +19,7 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; @@ -45,7 +46,8 @@ public void testStreams() throws IOException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); BytesStreamOutput streamOutput = new BytesStreamOutput(); @@ -67,7 +69,8 @@ public void testGetKnnEngine() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); assertEquals(knnEngine, modelMetadata.getKnnEngine()); @@ -84,7 +87,8 @@ public void testGetSpaceType() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); assertEquals(spaceType, modelMetadata.getSpaceType()); @@ -101,7 +105,8 @@ public void testGetDimension() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); assertEquals(dimension, modelMetadata.getDimension()); @@ -118,7 +123,8 @@ public void testGetState() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); assertEquals(modelState, modelMetadata.getState()); @@ -135,7 +141,8 @@ public void testGetTimestamp() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); assertEquals(timeValue, modelMetadata.getTimestamp()); @@ -152,7 +159,8 @@ public void testDescription() { description, "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); assertEquals(description, modelMetadata.getDescription()); @@ -169,7 +177,8 @@ public void testGetError() { "", error, "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); assertEquals(error, modelMetadata.getError()); @@ -186,7 +195,8 @@ public void testSetState() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); assertEquals(modelState, modelMetadata.getState()); @@ -207,7 +217,8 @@ public void testSetError() { "", error, "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); assertEquals(error, modelMetadata.getError()); @@ -244,7 +255,9 @@ public void testToString() { + "," + nodeAssignment + "," - + methodComponentContext.toClusterStateString(); + + methodComponentContext.toClusterStateString() + + "," + + VectorDataType.FLOAT.getValue(); ModelMetadata modelMetadata = new ModelMetadata( knnEngine, @@ -255,7 +268,8 @@ public void testToString() { description, error, nodeAssignment, - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); assertEquals(expected, modelMetadata.toString()); @@ -275,7 +289,8 @@ public void testEquals() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata2 = new ModelMetadata( KNNEngine.FAISS, @@ -286,7 +301,8 @@ public void testEquals() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata3 = new ModelMetadata( @@ -298,7 +314,8 @@ public void testEquals() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata4 = new ModelMetadata( KNNEngine.FAISS, @@ -309,7 +326,8 @@ public void testEquals() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata5 = new ModelMetadata( KNNEngine.FAISS, @@ -320,7 +338,8 @@ public void testEquals() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata6 = new ModelMetadata( KNNEngine.FAISS, @@ -331,7 +350,8 @@ public void testEquals() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata7 = new ModelMetadata( KNNEngine.FAISS, @@ -342,7 +362,8 @@ public void testEquals() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata8 = new ModelMetadata( KNNEngine.FAISS, @@ -353,7 +374,8 @@ public void testEquals() { "diff descript", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata9 = new ModelMetadata( KNNEngine.FAISS, @@ -364,7 +386,8 @@ public void testEquals() { "", "diff error", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata10 = new ModelMetadata( @@ -376,7 +399,8 @@ public void testEquals() { "", "", "", - new MethodComponentContext("test", Collections.emptyMap()) + new MethodComponentContext("test", Collections.emptyMap()), + VectorDataType.FLOAT ); assertEquals(modelMetadata1, modelMetadata1); @@ -406,7 +430,8 @@ public void testHashCode() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata2 = new ModelMetadata( KNNEngine.FAISS, @@ -417,7 +442,8 @@ public void testHashCode() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata3 = new ModelMetadata( @@ -429,7 +455,8 @@ public void testHashCode() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata4 = new ModelMetadata( KNNEngine.FAISS, @@ -440,7 +467,8 @@ public void testHashCode() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata5 = new ModelMetadata( KNNEngine.FAISS, @@ -451,7 +479,8 @@ public void testHashCode() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata6 = new ModelMetadata( KNNEngine.FAISS, @@ -462,7 +491,8 @@ public void testHashCode() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata7 = new ModelMetadata( KNNEngine.FAISS, @@ -473,7 +503,8 @@ public void testHashCode() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata8 = new ModelMetadata( KNNEngine.FAISS, @@ -484,7 +515,8 @@ public void testHashCode() { "diff descript", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata9 = new ModelMetadata( KNNEngine.FAISS, @@ -495,7 +527,8 @@ public void testHashCode() { "", "diff error", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata10 = new ModelMetadata( @@ -507,7 +540,8 @@ public void testHashCode() { "", "", "", - new MethodComponentContext("test", Collections.emptyMap()) + new MethodComponentContext("test", Collections.emptyMap()), + VectorDataType.FLOAT ); assertEquals(modelMetadata1.hashCode(), modelMetadata1.hashCode()); @@ -550,7 +584,9 @@ public void testFromString() { + "," + nodeAssignment + "," - + methodComponentContext.toClusterStateString(); + + methodComponentContext.toClusterStateString() + + "," + + VectorDataType.FLOAT.getValue(); String stringRep2 = knnEngine.getName() + "," @@ -564,7 +600,9 @@ public void testFromString() { + "," + description + "," - + error; + + error + + "," + + VectorDataType.FLOAT.getValue(); ModelMetadata expected1 = new ModelMetadata( knnEngine, @@ -575,7 +613,8 @@ public void testFromString() { description, error, nodeAssignment, - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata expected2 = new ModelMetadata( @@ -587,7 +626,8 @@ public void testFromString() { description, error, "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata fromString1 = ModelMetadata.fromString(stringRep1); @@ -620,7 +660,8 @@ public void testFromResponseMap() throws IOException { description, error, nodeAssignment, - methodComponentContext + methodComponentContext, + VectorDataType.FLOAT ); ModelMetadata expected2 = new ModelMetadata( @@ -632,7 +673,8 @@ public void testFromResponseMap() throws IOException { description, error, "", - emptyMethodComponentContext + emptyMethodComponentContext, + VectorDataType.FLOAT ); Map metadataAsMap = new HashMap<>(); metadataAsMap.put(KNNConstants.KNN_ENGINE, knnEngine.getName()); @@ -643,6 +685,7 @@ public void testFromResponseMap() throws IOException { metadataAsMap.put(KNNConstants.MODEL_DESCRIPTION, description); metadataAsMap.put(KNNConstants.MODEL_ERROR, error); metadataAsMap.put(KNNConstants.MODEL_NODE_ASSIGNMENT, nodeAssignment); + metadataAsMap.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue()); XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); builder = methodComponentContext.toXContent(builder, ToXContent.EMPTY_PARAMS).endObject(); @@ -678,7 +721,8 @@ public void testBlockCommasInDescription() { description, error, nodeAssignment, - methodComponentContext + methodComponentContext, + VectorDataType.FLOAT ) ); assertEquals("Model description cannot contain any commas: ','", e.getMessage()); diff --git a/src/test/java/org/opensearch/knn/indices/ModelTests.java b/src/test/java/org/opensearch/knn/indices/ModelTests.java index 13579acadd..773a10a2cd 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelTests.java @@ -15,6 +15,7 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import java.time.ZoneOffset; @@ -41,7 +42,8 @@ public void testInvalidConstructor() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), null, "test-model" @@ -62,7 +64,8 @@ public void testInvalidDimension() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[16], "test-model" @@ -80,7 +83,8 @@ public void testInvalidDimension() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[16], "test-model" @@ -98,7 +102,8 @@ public void testInvalidDimension() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[16], "test-model" @@ -117,7 +122,8 @@ public void testGetModelMetadata() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); Model model = new Model(modelMetadata, new byte[16], "test-model"); assertEquals(modelMetadata, model.getModelMetadata()); @@ -135,7 +141,8 @@ public void testGetModelBlob() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, "test-model" @@ -155,7 +162,8 @@ public void testGetLength() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[size], "test-model" @@ -172,7 +180,8 @@ public void testGetLength() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), null, "test-model" @@ -192,7 +201,8 @@ public void testSetModelBlob() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), blob1, "test-model" @@ -209,17 +219,50 @@ public void testEquals() { String time = ZonedDateTime.now(ZoneOffset.UTC).toString(); Model model1 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L1, + 2, + ModelState.CREATED, + time, + "", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.FLOAT + ), new byte[16], "test-model-1" ); Model model2 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L1, + 2, + ModelState.CREATED, + time, + "", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.FLOAT + ), new byte[16], "test-model-1" ); Model model3 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L2, + 2, + ModelState.CREATED, + time, + "", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.FLOAT + ), new byte[16], "test-model-2" ); @@ -234,17 +277,50 @@ public void testHashCode() { String time = ZonedDateTime.now(ZoneOffset.UTC).toString(); Model model1 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L1, + 2, + ModelState.CREATED, + time, + "", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.FLOAT + ), new byte[16], "test-model-1" ); Model model2 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L1, + 2, + ModelState.CREATED, + time, + "", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.FLOAT + ), new byte[16], "test-model-1" ); Model model3 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L1, + 2, + ModelState.CREATED, + time, + "", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.FLOAT + ), new byte[16], "test-model-2" ); @@ -274,7 +350,8 @@ public void testModelFromSourceMap() { description, error, nodeAssignment, - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); Map modelAsMap = new HashMap<>(); modelAsMap.put(KNNConstants.MODEL_ID, modelID); @@ -287,6 +364,7 @@ public void testModelFromSourceMap() { modelAsMap.put(KNNConstants.MODEL_ERROR, error); modelAsMap.put(KNNConstants.MODEL_NODE_ASSIGNMENT, nodeAssignment); modelAsMap.put(KNNConstants.MODEL_BLOB_PARAMETER, "aGVsbG8="); + modelAsMap.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue()); byte[] blob1 = "hello".getBytes(); Model expected = new Model(metadata, blob1, modelID); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java b/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java index a6985e72a7..2106e31acc 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java @@ -20,6 +20,7 @@ import org.opensearch.knn.index.KNNClusterUtil; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelMetadata; @@ -43,7 +44,8 @@ private ModelMetadata getModelMetadata(ModelState state) { "test model", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); } @@ -68,7 +70,7 @@ public void testXContent() throws IOException { Model model = new Model(getModelMetadata(ModelState.CREATED), testModelBlob, modelId); GetModelResponse getModelResponse = new GetModelResponse(model); String expectedResponseString = - "{\"model_id\":\"test-model\",\"model_blob\":\"aGVsbG8=\",\"state\":\"created\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\",\"model_definition\":{\"name\":\"\",\"parameters\":{}}}"; + "{\"model_id\":\"test-model\",\"model_blob\":\"aGVsbG8=\",\"state\":\"created\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\",\"model_definition\":{\"name\":\"\",\"parameters\":{}},\"data_type\":\"float\"}"; XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); getModelResponse.toXContent(xContentBuilder, null); assertEquals(expectedResponseString, xContentBuilder.toString()); @@ -84,7 +86,7 @@ public void testXContentWithNoModelBlob() throws IOException { Model model = new Model(getModelMetadata(ModelState.FAILED), null, modelId); GetModelResponse getModelResponse = new GetModelResponse(model); String expectedResponseString = - "{\"model_id\":\"test-model\",\"model_blob\":\"\",\"state\":\"failed\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\",\"model_definition\":{\"name\":\"\",\"parameters\":{}}}"; + "{\"model_id\":\"test-model\",\"model_blob\":\"\",\"state\":\"failed\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\",\"model_definition\":{\"name\":\"\",\"parameters\":{}},\"data_type\":\"float\"}"; XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); getModelResponse.toXContent(xContentBuilder, null); assertEquals(expectedResponseString, xContentBuilder.toString()); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java index a2da83dadf..3b25bc0eb8 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java @@ -19,6 +19,7 @@ import org.opensearch.knn.KNNSingleNodeTestCase; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelCache; @@ -78,7 +79,8 @@ public void testNodeOperation_modelInCache() throws ExecutionException, Interrup "description", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[128], modelId diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java index 56c50aca1c..bdca896ff7 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java @@ -24,6 +24,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.transport.TransportService; @@ -307,7 +308,8 @@ public void testTrainingIndexSize() { trainingIndexName, "training-field", null, - "description" + "description", + VectorDataType.FLOAT ); // Mock client to return the right number of docs diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java index b39c486351..28e09a61e7 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java @@ -25,6 +25,7 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponentContext; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; @@ -61,7 +62,8 @@ public void testStreams() throws IOException { trainingIndex, trainingField, preferredNode, - description + description, + VectorDataType.FLOAT ); BytesStreamOutput streamOutput = new BytesStreamOutput(); @@ -74,6 +76,7 @@ public void testStreams() throws IOException { assertEquals(original1.getTrainingIndex(), copy1.getTrainingIndex()); assertEquals(original1.getTrainingField(), copy1.getTrainingField()); assertEquals(original1.getPreferredNodeId(), copy1.getPreferredNodeId()); + assertEquals(original1.getVectorDataType(), copy1.getVectorDataType()); // Also, check when preferred node and model id and description are null TrainingModelRequest original2 = new TrainingModelRequest( @@ -83,7 +86,8 @@ public void testStreams() throws IOException { trainingIndex, trainingField, null, - null + null, + VectorDataType.FLOAT ); streamOutput = new BytesStreamOutput(); @@ -96,6 +100,7 @@ public void testStreams() throws IOException { assertEquals(original2.getTrainingIndex(), copy2.getTrainingIndex()); assertEquals(original2.getTrainingField(), copy2.getTrainingField()); assertEquals(original2.getPreferredNodeId(), copy2.getPreferredNodeId()); + assertEquals(original2.getVectorDataType(), copy2.getVectorDataType()); } public void testGetters() { @@ -117,7 +122,8 @@ public void testGetters() { trainingIndex, trainingField, preferredNode, - description + description, + VectorDataType.FLOAT ); trainingModelRequest.setMaximumVectorCount(maxVectorCount); @@ -156,7 +162,8 @@ public void testValidation_invalid_modelIdAlreadyExists() { trainingIndex, trainingField, null, - null + null, + VectorDataType.FLOAT ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate @@ -170,7 +177,8 @@ public void testValidation_invalid_modelIdAlreadyExists() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); @@ -211,7 +219,8 @@ public void testValidation_blocked_modelId() { trainingIndex, trainingField, null, - null + null, + VectorDataType.FLOAT ); // Mock the model dao to return true to recognize that the modelId is in graveyard @@ -257,7 +266,8 @@ public void testValidation_invalid_invalidMethodContext() { trainingIndex, trainingField, null, - null + null, + VectorDataType.FLOAT ); // Mock the model dao to return null so that no exception is produced @@ -300,7 +310,8 @@ public void testValidation_invalid_trainingIndexDoesNotExist() { trainingIndex, trainingField, null, - null + null, + VectorDataType.FLOAT ); // Mock the model dao to return null so that no exception is produced @@ -346,7 +357,8 @@ public void testValidation_invalid_trainingFieldDoesNotExist() { trainingIndex, trainingField, null, - null + null, + VectorDataType.FLOAT ); // Mock the model dao to return null so that no exception is produced @@ -397,7 +409,8 @@ public void testValidation_invalid_trainingFieldNotKnnVector() { trainingIndex, trainingField, null, - null + null, + VectorDataType.FLOAT ); // Mock the model dao to return null so that no exception is produced @@ -452,7 +465,8 @@ public void testValidation_invalid_dimensionDoesNotMatch() { trainingIndex, trainingField, null, - null + null, + VectorDataType.FLOAT ); // Mock the model dao to return null so that no exception is produced @@ -509,7 +523,8 @@ public void testValidation_invalid_preferredNodeDoesNotExist() { trainingIndex, trainingField, preferredNode, - null + null, + VectorDataType.FLOAT ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate @@ -574,7 +589,8 @@ public void testValidation_invalid_descriptionToLong() { trainingIndex, trainingField, null, - description + description, + VectorDataType.FLOAT ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate @@ -618,7 +634,8 @@ public void testValidation_valid_trainingIndexBuiltFromMethod() { trainingIndex, trainingField, null, - null + null, + VectorDataType.FLOAT ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate @@ -655,7 +672,8 @@ public void testValidation_valid_trainingIndexBuiltFromModel() { trainingIndex, trainingField, null, - null + null, + VectorDataType.FLOAT ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelTransportActionTests.java index 950ce1fd08..221f50fe3c 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelTransportActionTests.java @@ -17,6 +17,7 @@ import org.opensearch.knn.KNNSingleNodeTestCase; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; @@ -72,7 +73,8 @@ public void testDoExecute() throws InterruptedException, ExecutionException, IOE trainingIndexName, trainingFieldName, null, - "test-detector" + "test-detector", + VectorDataType.FLOAT ); trainingModelRequest.setTrainingDataSizeInKB(estimateVectorSetSizeInKB(trainingDataCount, dimension)); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java index 5be907ebd3..e0d7c521c1 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java @@ -17,6 +17,7 @@ import org.opensearch.knn.common.exception.DeleteModelException; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelGraveyard; @@ -210,7 +211,8 @@ public void testClusterManagerOperation_GetIndicesUsingModel() throws IOExceptio "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, modelId diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java index 3719d124ac..2a016d98bd 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java @@ -15,6 +15,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; @@ -42,7 +43,8 @@ public void testStreams() throws IOException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); UpdateModelMetadataRequest updateModelMetadataRequest = new UpdateModelMetadataRequest(modelId, isRemoveRequest, modelMetadata); @@ -67,7 +69,8 @@ public void testValidate() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); UpdateModelMetadataRequest updateModelMetadataRequest1 = new UpdateModelMetadataRequest("test", true, null); @@ -107,7 +110,8 @@ public void testGetModelMetadata() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); UpdateModelMetadataRequest updateModelMetadataRequest = new UpdateModelMetadataRequest("test", true, modelMetadata); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java index ab0e4f506a..e16b720f0c 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java @@ -19,6 +19,7 @@ import org.opensearch.knn.KNNSingleNodeTestCase; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; @@ -68,7 +69,8 @@ public void testClusterManagerOperation() throws InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); // Get update transport action diff --git a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java index 06b96c57c7..57ecb8323d 100644 --- a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java +++ b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java @@ -18,6 +18,7 @@ import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryEntryContext; @@ -67,7 +68,8 @@ public void testGetModelId() { mock(NativeMemoryEntryContext.AnonymousEntryContext.class), 10, "", - "test-node" + "test-node", + VectorDataType.FLOAT ); assertEquals(modelId, trainingJob.getModelId()); @@ -96,7 +98,8 @@ public void testGetModel() { mock(NativeMemoryEntryContext.AnonymousEntryContext.class), dimension, description, - nodeAssignment + nodeAssignment, + VectorDataType.FLOAT ); Model model = new Model( @@ -109,7 +112,8 @@ public void testGetModel() { description, error, nodeAssignment, - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), null, modelID @@ -183,8 +187,8 @@ public void testRun_success() throws IOException, ExecutionException { modelContext, dimension, "", - "test-node" - + "test-node", + VectorDataType.FLOAT ); trainingJob.run(); @@ -262,8 +266,8 @@ public void testRun_failure_onGetTrainingDataAllocation() throws ExecutionExcept modelContext, dimension, "", - - "test-node" + "test-node", + VectorDataType.FLOAT ); trainingJob.run(); @@ -330,8 +334,8 @@ public void testRun_failure_onGetModelAnonymousAllocation() throws ExecutionExce modelContext, dimension, "", - - "test-node" + "test-node", + VectorDataType.FLOAT ); trainingJob.run(); @@ -397,7 +401,8 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep mock(NativeMemoryEntryContext.AnonymousEntryContext.class), dimension, "", - "test-node" + "test-node", + VectorDataType.FLOAT ); trainingJob.run(); @@ -470,7 +475,8 @@ public void testRun_failure_notEnoughTrainingData() throws ExecutionException { modelContext, dimension, "", - "test-node" + "test-node", + VectorDataType.FLOAT ); trainingJob.run();