diff --git a/jni/include/commons.h b/jni/include/commons.h index d024393777..4cdaf28fc9 100644 --- a/jni/include/commons.h +++ b/jni/include/commons.h @@ -19,12 +19,19 @@ namespace knn_jni { * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location * will throw Exception. * + * append tells the method to keep appending to the existing vector. Passing the value as false will clear the vector + * without reallocating new memory. This helps with reducing memory fragmentation and overhead of allocating + * and deallocating when the memory address needs to be reused. + * + * CAUTION: The behavior is undefined if the memory address is deallocated and the method is called + * * @param memoryAddress The address of the memory location where data will be stored. * @param data 2D float array containing data to be stored in native memory. * @param initialCapacity The initial capacity of the memory location. + * @param append whether to append or start from index 0 when called subsequently with the same address * @return memory address of std::vector where the data is stored. */ - jlong storeVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong); + jlong storeVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean); /** * This is utility function that can be used to store data in native memory. This function will allocate memory for @@ -33,12 +40,18 @@ namespace knn_jni { * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location * will throw Exception. * + * append tells the method to keep appending to the existing vector. Passing the value as false will clear the vector + * without reallocating new memory. This helps with reducing memory fragmentation and overhead of allocating + * and deallocating when the memory address needs to be reused. 
+ * + * CAUTION: The behavior is undefined if the memory address is deallocated and the method is called + * * @param memoryAddress The address of the memory location where data will be stored. * @param data 2D byte array containing data to be stored in native memory. * @param initialCapacity The initial capacity of the memory location. * @return memory address of std::vector where the data is stored. */ - jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong); + jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean); /** * Free up the memory allocated for the data stored in memory address. This function should be used with the memory diff --git a/jni/include/org_opensearch_knn_jni_JNICommons.h b/jni/include/org_opensearch_knn_jni_JNICommons.h index 89de76520e..03c0d023a8 100644 --- a/jni/include/org_opensearch_knn_jni_JNICommons.h +++ b/jni/include/org_opensearch_knn_jni_JNICommons.h @@ -21,10 +21,10 @@ extern "C" { /* * Class: org_opensearch_knn_jni_JNICommons * Method: storeVectorData - * Signature: (J[[FJJ) + * Signature: (J[[FJZ)J */ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData - (JNIEnv *, jclass, jlong, jobjectArray, jlong); + (JNIEnv *, jclass, jlong, jobjectArray, jlong, jboolean); /* * Class: org_opensearch_knn_jni_JNICommons @@ -32,7 +32,7 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData * Signature: (J[[FJJ) */ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData - (JNIEnv *, jclass, jlong, jobjectArray, jlong); + (JNIEnv *, jclass, jlong, jobjectArray, jlong, jboolean); /* * Class: org_opensearch_knn_jni_JNICommons diff --git a/jni/src/commons.cpp b/jni/src/commons.cpp index 13f59194e3..f9764db736 100644 --- a/jni/src/commons.cpp +++ b/jni/src/commons.cpp @@ -18,7 +18,7 @@ #include "commons.h" jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, 
JNIEnv *env, jlong memoryAddressJ, - jobjectArray dataJ, jlong initialCapacityJ) { + jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) { std::vector *vect; if ((long) memoryAddressJ == 0) { vect = new std::vector(); @@ -26,6 +26,11 @@ jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIE } else { vect = reinterpret_cast*>(memoryAddressJ); } + + if (appendJ == JNI_FALSE) { + vect->clear(); + } + int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, dataJ); jniUtil->Convert2dJavaObjectArrayAndStoreToFloatVector(env, dataJ, dim, vect); @@ -33,7 +38,7 @@ jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIE } jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ, - jobjectArray dataJ, jlong initialCapacityJ) { + jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) { std::vector *vect; if ((long) memoryAddressJ == 0) { vect = new std::vector(); @@ -41,6 +46,11 @@ jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil, } else { vect = reinterpret_cast*>(memoryAddressJ); } + + if (appendJ == JNI_FALSE) { + vect->clear(); + } + int dim = jniUtil->GetInnerDimensionOf2dJavaByteArray(env, dataJ); jniUtil->Convert2dJavaObjectArrayAndStoreToByteVector(env, dataJ, dim, vect); diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 2b804a672f..6e7dd49127 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -75,7 +75,6 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JN std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); knn_jni::faiss_wrapper::InsertToIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexAddress, threadCount, &indexService); - delete 
reinterpret_cast*>(vectorsAddressJ); } catch (...) { // NOTE: ADDING DELETE STATEMENT HERE CAUSES A CRASH! jniUtil.CatchCppExceptionAndThrowJava(env); diff --git a/jni/src/org_opensearch_knn_jni_JNICommons.cpp b/jni/src/org_opensearch_knn_jni_JNICommons.cpp index 0bc2e46331..7432c44d3a 100644 --- a/jni/src/org_opensearch_knn_jni_JNICommons.cpp +++ b/jni/src/org_opensearch_knn_jni_JNICommons.cpp @@ -38,11 +38,11 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) { JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData(JNIEnv * env, jclass cls, -jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ) +jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) { try { - return knn_jni::commons::storeVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ); + return knn_jni::commons::storeVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ, appendJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } @@ -50,11 +50,11 @@ jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ) } JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData(JNIEnv * env, jclass cls, -jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ) +jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) { try { - return knn_jni::commons::storeByteVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ); + return knn_jni::commons::storeByteVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ, appendJ); } catch (...) 
{ jniUtil.CatchCppExceptionAndThrowJava(env); } diff --git a/jni/tests/commons_test.cpp b/jni/tests/commons_test.cpp index 98def8807b..d469fe268f 100644 --- a/jni/tests/commons_test.cpp +++ b/jni/tests/commons_test.cpp @@ -33,7 +33,7 @@ TEST(CommonsTests, BasicAssertions) { testing::NiceMock mockJNIUtil; jlong memoryAddress = knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, (jlong)0, - reinterpret_cast(&data), (jlong)(totalNumberOfVector * dim)); + reinterpret_cast(&data), (jlong)(totalNumberOfVector * dim), true); ASSERT_NE(memoryAddress, 0); auto *vect = reinterpret_cast*>(memoryAddress); ASSERT_EQ(vect->size(), data.size() * dim); @@ -48,12 +48,13 @@ TEST(CommonsTests, BasicAssertions) { } data2.push_back(vector); memoryAddress = knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, memoryAddress, - reinterpret_cast(&data2), (jlong)(totalNumberOfVector * dim)); + reinterpret_cast(&data2), (jlong)(totalNumberOfVector * dim), true); ASSERT_NE(memoryAddress, 0); ASSERT_EQ(memoryAddress, oldMemoryAddress); vect = reinterpret_cast*>(memoryAddress); int currentIndex = 0; - ASSERT_EQ(vect->size(), totalNumberOfVector*dim); + std::cout << vect->size() + "\n"; + ASSERT_EQ(vect->size(), totalNumberOfVector * dim); ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim); // Validate if all vectors data are at correct location @@ -70,6 +71,113 @@ TEST(CommonsTests, BasicAssertions) { currentIndex++; } } + + // test append == false + std::vector> data3; + std::vector vecto3; + for(int j = 0 ; j < dim ; j ++) { + vecto3.push_back((float)j); + } + data3.push_back(vecto3); + memoryAddress = knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, memoryAddress, + reinterpret_cast(&data3), (jlong)(totalNumberOfVector * dim), false); + ASSERT_NE(memoryAddress, 0); + ASSERT_EQ(memoryAddress, oldMemoryAddress); + vect = reinterpret_cast*>(memoryAddress); + + ASSERT_EQ(vect->size(), dim); //Since we just added 1 vector + ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim); 
//This is the initial capacity allocated + + currentIndex = 0; + for(auto & i : data3) { + for(float j : i) { + ASSERT_FLOAT_EQ(vect->at(currentIndex), j); + currentIndex++; + } + } + + // Check that freeing vector data works + knn_jni::commons::freeVectorData(memoryAddress); +} + +TEST(StoreByteVectorTest, BasicAssertions) { + long dim = 3; + long totalNumberOfVector = 5; + std::vector> data; + for(int i = 0 ; i < totalNumberOfVector - 1 ; i++) { + std::vector vector; + for(int j = 0 ; j < dim ; j ++) { + vector.push_back((uint8_t)j); + } + data.push_back(vector); + } + JNIEnv *jniEnv = nullptr; + + testing::NiceMock mockJNIUtil; + + jlong memoryAddress = knn_jni::commons::storeByteVectorData(&mockJNIUtil, jniEnv, (jlong)0, + reinterpret_cast(&data), (jlong)(totalNumberOfVector * dim), true); + ASSERT_NE(memoryAddress, 0); + auto *vect = reinterpret_cast*>(memoryAddress); + ASSERT_EQ(vect->size(), data.size() * dim); + ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim); + + // Check by inserting more vectors at same memory location + jlong oldMemoryAddress = memoryAddress; + std::vector> data2; + std::vector vector; + for(int j = 0 ; j < dim ; j ++) { + vector.push_back((uint8_t)j); + } + data2.push_back(vector); + memoryAddress = knn_jni::commons::storeByteVectorData(&mockJNIUtil, jniEnv, memoryAddress, + reinterpret_cast(&data2), (jlong)(totalNumberOfVector * dim), true); + ASSERT_NE(memoryAddress, 0); + ASSERT_EQ(memoryAddress, oldMemoryAddress); + vect = reinterpret_cast*>(memoryAddress); + int currentIndex = 0; + ASSERT_EQ(vect->size(), totalNumberOfVector*dim); + ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim); + + // Validate if all vectors data are at correct location + for(auto & i : data) { + for(uint8_t j : i) { + ASSERT_EQ(vect->at(currentIndex), j); + currentIndex++; + } + } + + for(auto & i : data2) { + for(uint8_t j : i) { + ASSERT_EQ(vect->at(currentIndex), j); + currentIndex++; + } + } + + // test append == false + std::vector> data3; + 
std::vector vecto3; + for(int j = 0 ; j < dim ; j ++) { + vecto3.push_back((uint8_t)j); + } + data3.push_back(vecto3); + memoryAddress = knn_jni::commons::storeByteVectorData(&mockJNIUtil, jniEnv, memoryAddress, + reinterpret_cast(&data3), (jlong)(totalNumberOfVector * dim), false); + ASSERT_NE(memoryAddress, 0); + ASSERT_EQ(memoryAddress, oldMemoryAddress); + vect = reinterpret_cast*>(memoryAddress); + + ASSERT_EQ(vect->size(), dim); + ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim); + + currentIndex = 0; + for(auto & i : data3) { + for(uint8_t j : i) { + ASSERT_EQ(vect->at(currentIndex), j); + currentIndex++; + } + } + // Check that freeing vector data works knn_jni::commons::freeVectorData(memoryAddress); } diff --git a/src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java b/src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java new file mode 100644 index 0000000000..a7c8d585ce --- /dev/null +++ b/src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.common; + +import lombok.experimental.UtilityClass; +import org.apache.lucene.index.FieldInfo; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.indices.ModelMetadata; + +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.indices.ModelUtil.getModelMetadata; + +@UtilityClass +public class FieldInfoExtractor { + + public static KNNEngine extractKNNEngine(final FieldInfo field) { + final ModelMetadata modelMetadata = getModelMetadata(field.attributes().get(MODEL_ID)); + if (modelMetadata != null) { + return modelMetadata.getKnnEngine(); + } + final String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName()); + return KNNEngine.getEngine(engineName); + } + + public static VectorDataType 
extractVectorDataType(final FieldInfo field) { + String vectorDataTypeString = field.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD); + if (vectorDataTypeString == null) { + final ModelMetadata modelMetadata = getModelMetadata(field.attributes().get(MODEL_ID)); + if (modelMetadata != null) { + VectorDataType vectorDataType = modelMetadata.getVectorDataType(); + vectorDataTypeString = vectorDataType == null ? null : vectorDataType.getValue(); + } + } + return vectorDataTypeString != null ? VectorDataType.get(vectorDataTypeString) : VectorDataType.DEFAULT; + } +} diff --git a/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java b/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java index fd9e5b6c23..9381f73e80 100644 --- a/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java +++ b/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java @@ -5,10 +5,12 @@ package org.opensearch.knn.common; -import java.util.Objects; import lombok.AccessLevel; import lombok.NoArgsConstructor; +import java.util.ArrayList; +import java.util.Objects; + @NoArgsConstructor(access = AccessLevel.PRIVATE) public class KNNVectorUtil { /** @@ -42,4 +44,18 @@ public static boolean isZeroVector(float[] vector) { } return true; } + + /** + * Creates an int overflow safe arraylist. 
If there is an overflow it will create a list with default initial size + * @param batchSize size to allocate + * @return an arrayList + */ + public static ArrayList createArrayList(long batchSize) { + try { + return new ArrayList<>(Math.toIntExact(batchSize)); + } catch (Exception exception) { + // No-op + } + return new ArrayList<>(); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index 69229036ed..8ed4dfeadf 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -8,8 +8,11 @@ import lombok.AllArgsConstructor; import lombok.extern.log4j.Log4j2; import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsFormat; import org.opensearch.knn.index.codec.params.KNNScalarQuantizedVectorsFormatParams; import org.opensearch.knn.index.codec.params.KNNVectorsFormatParams; import org.opensearch.knn.index.engine.KNNEngine; @@ -17,6 +20,7 @@ import org.opensearch.knn.index.mapper.KNNMappingConfig; import org.opensearch.knn.index.mapper.KNNVectorFieldType; +import java.util.Map; import java.util.Optional; import java.util.function.Function; import java.util.function.Supplier; @@ -78,42 +82,47 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { KNNMethodContext knnMethodContext = knnMappingConfig.getKnnMethodContext() .orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); - var params = knnMethodContext.getMethodComponentContext().getParameters(); + final 
KNNEngine engine = knnMethodContext.getKnnEngine(); + final Map params = knnMethodContext.getMethodComponentContext().getParameters(); - if (knnMethodContext.getKnnEngine() == KNNEngine.LUCENE && params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) { - KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams( - params, - defaultMaxConnections, - defaultBeamWidth - ); - if (knnScalarQuantizedVectorsFormatParams.validate(params)) { - log.debug( - "Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\"", - field, - MAX_CONNECTIONS, - knnScalarQuantizedVectorsFormatParams.getMaxConnections(), - BEAM_WIDTH, - knnScalarQuantizedVectorsFormatParams.getBeamWidth(), - LUCENE_SQ_CONFIDENCE_INTERVAL, - knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(), - LUCENE_SQ_BITS, - knnScalarQuantizedVectorsFormatParams.getBits() + if (engine.equals(KNNEngine.LUCENE)) { + if (params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) { + KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams( + params, + defaultMaxConnections, + defaultBeamWidth ); - return scalarQuantizedVectorsFormatSupplier.apply(knnScalarQuantizedVectorsFormatParams); + if (knnScalarQuantizedVectorsFormatParams.validate(params)) { + log.debug( + "Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\"", + field, + MAX_CONNECTIONS, + knnScalarQuantizedVectorsFormatParams.getMaxConnections(), + BEAM_WIDTH, + knnScalarQuantizedVectorsFormatParams.getBeamWidth(), + LUCENE_SQ_CONFIDENCE_INTERVAL, + knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(), + LUCENE_SQ_BITS, + knnScalarQuantizedVectorsFormatParams.getBits() + ); + return scalarQuantizedVectorsFormatSupplier.apply(knnScalarQuantizedVectorsFormatParams); + } } + 
KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth); + log.debug( + "Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"", + field, + MAX_CONNECTIONS, + knnVectorsFormatParams.getMaxConnections(), + BEAM_WIDTH, + knnVectorsFormatParams.getBeamWidth() + ); + return vectorsFormatSupplier.apply(knnVectorsFormatParams); } - KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth); - log.debug( - "Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"", - field, - MAX_CONNECTIONS, - knnVectorsFormatParams.getMaxConnections(), - BEAM_WIDTH, - knnVectorsFormatParams.getBeamWidth() - ); - return vectorsFormatSupplier.apply(knnVectorsFormatParams); + // All native engines to use NativeEngines990KnnVectorsFormat + return new NativeEngines990KnnVectorsFormat(new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer())); } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index d2117a3bc8..63173c34d0 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -5,11 +5,12 @@ package org.opensearch.knn.index.codec.KNN80Codec; -import lombok.NonNull; import lombok.extern.log4j.Log4j2; import org.opensearch.common.StopWatch; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.indices.ModelCache; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import 
org.apache.lucene.codecs.DocValuesConsumer; @@ -20,12 +21,12 @@ import org.apache.lucene.index.SegmentWriteState; import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; -import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.plugin.stats.KNNGraphValue; import java.io.IOException; -import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.FieldInfoExtractor.extractKNNEngine; +import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; /** * This class writes the KNN docvalues to the segments @@ -49,7 +50,7 @@ public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) th if (isKNNBinaryFieldRequired(field)) { StopWatch stopWatch = new StopWatch(); stopWatch.start(); - addKNNBinaryField(field, valuesProducer, false, true); + addKNNBinaryField(field, valuesProducer, false); stopWatch.stop(); long time_in_millis = stopWatch.totalTime().millis(); KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.set(KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() + time_in_millis); @@ -58,25 +59,26 @@ public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) th } private boolean isKNNBinaryFieldRequired(FieldInfo field) { - final KNNEngine knnEngine = getKNNEngine(field); + final KNNEngine knnEngine = extractKNNEngine(field); log.debug(String.format("Read engine [%s] for field [%s]", knnEngine.getName(), field.getName())); return field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD) && KNNEngine.getEnginesThatCreateCustomSegmentFiles().stream().anyMatch(engine -> engine == knnEngine); } - private KNNEngine getKNNEngine(@NonNull FieldInfo field) { - final String modelId = field.attributes().get(MODEL_ID); - if (modelId != null) { - var model = ModelCache.getInstance().get(modelId); - return model.getModelMetadata().getKnnEngine(); + public void addKNNBinaryField(FieldInfo field, 
DocValuesProducer valuesProducer, boolean isMerge) throws IOException { + final VectorDataType vectorDataType = extractVectorDataType(field); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, valuesProducer.getBinary(field)); + + boolean indexWritten; + if (isMerge) { + indexWritten = NativeIndexWriter.getWriter(field, state).mergeIndex(knnVectorValues); + } else { + indexWritten = NativeIndexWriter.getWriter(field, state).refreshIndex(knnVectorValues); } - final String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName()); - return KNNEngine.getEngine(engineName); - } - public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh) - throws IOException { - NativeIndexWriter.getWriter(field).createKNNIndex(field, valuesProducer, state, isMerge, isRefresh); + if (!indexWritten) { + log.warn("Index not written for for field [{}]", field.getName()); + } } /** @@ -95,7 +97,7 @@ public void merge(MergeState mergeState) { if (type == DocValuesType.BINARY && fieldInfo.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) { StopWatch stopWatch = new StopWatch(); stopWatch.start(); - addKNNBinaryField(fieldInfo, new KNN80DocValuesReader(mergeState), true, false); + addKNNBinaryField(fieldInfo, new KNN80DocValuesReader(mergeState), true); stopWatch.stop(); long time_in_millis = stopWatch.totalTime().millis(); KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.set(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() + time_in_millis); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java index e4860af31d..657a6a31e4 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java +++ 
b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java @@ -30,6 +30,7 @@ */ class NativeEngineFieldVectorsWriter extends KnnFieldVectorsWriter { private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NativeEngineFieldVectorsWriter.class); + @Getter private final FieldInfo fieldInfo; /** * We are using a map here instead of list, because for sampler interface for quantization we have to advance the iterator diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index b81ec9789b..aef360e31d 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -12,23 +12,33 @@ package org.opensearch.knn.index.codec.KNN990Codec; import lombok.RequiredArgsConstructor; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import static 
org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; + /** * A KNNVectorsWriter class for writing the vector data strcutures and flat vectors for Native Engines. */ +@Log4j2 @RequiredArgsConstructor public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NativeEngines990KnnVectorsWriter.class); @@ -46,8 +56,6 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { @Override public KnnFieldVectorsWriter addField(final FieldInfo fieldInfo) throws IOException { final NativeEngineFieldVectorsWriter newField = NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream); - // TODO: we can build the graph here too iteratively. but right now I am skipping that as we need iterative - // graph build support on the JNI layer. fields.add(newField); return flatVectorsWriter.addField(fieldInfo, newField); } @@ -62,14 +70,48 @@ public KnnFieldVectorsWriter addField(final FieldInfo fieldInfo) throws IOExc public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { // simply write data in the flat file flatVectorsWriter.flush(maxDoc, sortMap); - // TODO: add code for creating Vector datastructures during lucene flush operation + for (final NativeEngineFieldVectorsWriter field : fields) { + final VectorDataType vectorDataType = extractVectorDataType(field.getFieldInfo()); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + vectorDataType, + field.getDocsWithField(), + field.getVectors() + ); + + // TODO: Extract quantization state here + boolean indexWritten = NativeIndexWriter.getWriter(field.getFieldInfo(), segmentWriteState).refreshIndex(knnVectorValues); + if (!indexWritten) { + log.warn("Wasn't able to flush KNN index for field [{}]", field.getFieldInfo().getName()); + } + } } @Override public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState) throws 
IOException { // This will ensure that we are merging the FlatIndex during force merge. flatVectorsWriter.mergeOneField(fieldInfo, mergeState); - // TODO: add code for creating Vector datastructures during merge operation + + // For merge, pick values from flat vector and reindex again. This will use the flush operation to create graphs + final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); + final KNNVectorValues knnVectorValues; + switch (fieldInfo.getVectorEncoding()) { + case FLOAT32: + final FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedFloats); + break; + case BYTE: + final ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); + knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedBytes); + break; + default: + throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]"); + } + + // TODO: Extract Quantization state here + boolean indexWritten = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState).mergeIndex(knnVectorValues); + if (!indexWritten) { + log.warn("Wasn't able to merge KNN index for field [{}]", fieldInfo.getName()); + } } /** diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/BulkVectorTransferIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/BulkVectorTransferIndexBuildStrategy.java new file mode 100644 index 0000000000..89e1457dfb --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/BulkVectorTransferIndexBuildStrategy.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.VectorDataType; +import 
org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.codec.transfer.OffHeapByteVectorTransfer; +import org.opensearch.knn.index.codec.transfer.OffHeapFloatVectorTransfer; +import org.opensearch.knn.index.codec.transfer.VectorTransfer; +import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.jni.JNIService; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; + +/** + * Transfers all vectors to offheap and then builds an index + */ +final class BulkVectorTransferIndexBuildStrategy implements NativeIndexBuildStrategy { + + private static BulkVectorTransferIndexBuildStrategy INSTANCE = new BulkVectorTransferIndexBuildStrategy(); + + public static BulkVectorTransferIndexBuildStrategy getInstance() { + return INSTANCE; + } + + private BulkVectorTransferIndexBuildStrategy() {} + + public boolean buildIndex(final BuildIndexParams indexInfo, final KNNVectorValues knnVectorValues) { + + try (final VectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), knnVectorValues)) { + vectorTransfer.transferBatch(); + + final Map params = indexInfo.getParameters(); + // Currently this is if else as there are only two cases, with more cases this will have to be made + // more maintainable + if (params.containsKey(MODEL_ID)) { + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.createIndexFromTemplate( + vectorTransfer.getTransferredDocsIds(), + vectorTransfer.getVectorAddress(), + knnVectorValues.dimension(), + indexInfo.getIndexPath(), + (byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER), + indexInfo.getParameters(), + indexInfo.getKnnEngine() + ); + return null; + }); + } else { + AccessController.doPrivileged((PrivilegedAction) () -> { + 
JNIService.createIndex( + vectorTransfer.getTransferredDocsIds(), + vectorTransfer.getVectorAddress(), + knnVectorValues.dimension(), + indexInfo.getIndexPath(), + indexInfo.getParameters(), + indexInfo.getKnnEngine() + ); + return null; + }); + } + + } catch (Exception exception) { + throw new RuntimeException("Failed to build index", exception); + } + + return true; + } + + private VectorTransfer getVectorTransfer(VectorDataType vectorDataType, KNNVectorValues knnVectorValues) throws IOException { + switch (vectorDataType) { + case FLOAT: + return new OffHeapFloatVectorTransfer((KNNFloatVectorValues) knnVectorValues, knnVectorValues.totalLiveDocs()); + case BINARY: + case BYTE: + return new OffHeapByteVectorTransfer((KNNVectorValues) knnVectorValues, knnVectorValues.totalLiveDocs()); + default: + throw new IllegalArgumentException("Unsupported vector data type: " + vectorDataType); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java new file mode 100644 index 0000000000..5c77568843 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java @@ -0,0 +1,94 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.codec.transfer.OffHeapByteVectorTransfer; +import org.opensearch.knn.index.codec.transfer.OffHeapFloatVectorTransfer; +import org.opensearch.knn.index.codec.transfer.VectorTransfer; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import 
org.opensearch.knn.jni.JNIService; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Map; + +/** + * Iteratively builds the index. + */ +final class MemOptimizedNativeIndexBuildStrategy implements NativeIndexBuildStrategy { + + private static MemOptimizedNativeIndexBuildStrategy INSTANCE = new MemOptimizedNativeIndexBuildStrategy(); + + public static MemOptimizedNativeIndexBuildStrategy getInstance() { + return INSTANCE; + } + + private MemOptimizedNativeIndexBuildStrategy() {} + + public boolean buildIndex(BuildIndexParams indexInfo, final KNNVectorValues knnVectorValues) throws IOException { + if (knnVectorValues.docId() == -1) { + // Iterating once so the dimension() does not throw error; + knnVectorValues.nextDoc(); + knnVectorValues.getVector(); + } + KNNEngine engine = indexInfo.getKnnEngine(); + Map indexParameters = indexInfo.getParameters(); + + // Initialize the index + long indexMemoryAddress = AccessController.doPrivileged( + (PrivilegedAction) () -> JNIService.initIndexFromScratch( + knnVectorValues.totalLiveDocs(), + knnVectorValues.dimension(), + indexParameters, + engine + ) + ); + + try (final VectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), knnVectorValues)) { + + while (vectorTransfer.hasNext()) { + vectorTransfer.transferBatch(); + long vectorAddress = vectorTransfer.getVectorAddress(); + int[] docs = vectorTransfer.getTransferredDocsIds(); + + // Insert vectors + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.insertToIndex(docs, vectorAddress, knnVectorValues.dimension(), indexParameters, indexMemoryAddress, engine); + return null; + }); + } + + // Write vector + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.writeIndex(indexInfo.getIndexPath(), indexMemoryAddress, engine, indexParameters); + return null; + }); + + } catch (Exception exception) { + throw new RuntimeException("Failed to 
build index", exception); + } + + return true; + } + + // TODO: Will probably need a factory once quantization is added + private VectorTransfer getVectorTransfer(VectorDataType vectorDataType, KNNVectorValues knnVectorValues) throws IOException { + switch (vectorDataType) { + case FLOAT: + return new OffHeapFloatVectorTransfer((KNNFloatVectorValues) knnVectorValues); + case BINARY: + case BYTE: + return new OffHeapByteVectorTransfer((KNNVectorValues) knnVectorValues); + default: + throw new IllegalArgumentException("Unsupported vector data type: " + vectorDataType); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategy.java new file mode 100644 index 0000000000..91aaa49a27 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategy.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; + +import java.io.IOException; + +/** + * Interface which dictates how the index needs to be built + */ +public interface NativeIndexBuildStrategy { + + boolean buildIndex(BuildIndexParams indexInfo, final KNNVectorValues knnVectorValues) throws IOException; +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java index 8cd75c4d70..a1493d2f8d 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java @@ -5,73 +5,60 @@ package org.opensearch.knn.index.codec.nativeindex; -import java.io.IOException; -import 
java.io.OutputStream; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.file.Files; -import java.nio.file.Paths; -import java.nio.file.StandardOpenOption; -import java.util.Map; - -import org.apache.lucene.codecs.DocValuesProducer; -import org.apache.lucene.index.BinaryDocValues; +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.FSDirectory; import org.apache.lucene.store.FilterDirectory; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.codec.transfer.VectorTransfer; -import org.opensearch.knn.index.codec.transfer.VectorTransferByte; -import org.opensearch.knn.index.codec.transfer.VectorTransferFloat; -import org.opensearch.knn.index.codec.util.KNNCodecUtil; -import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.util.IndexUtil; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelCache; import org.opensearch.knn.plugin.stats.KNNGraphValue; -import lombok.Builder; -import lombok.NonNull; -import lombok.Value; -import lombok.extern.log4j.Log4j2; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import 
java.nio.file.Files; +import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; +import java.util.HashMap; +import java.util.Map; import static org.apache.lucene.codecs.CodecUtil.FOOTER_MAGIC; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.opensearch.knn.common.FieldInfoExtractor.extractKNNEngine; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName; +import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; /** - * Abstract class to build the KNN index and write it to disk + * Writes KNN Index for a field in a segment. This is intended to be used for native engines */ +@AllArgsConstructor @Log4j2 -public abstract class NativeIndexWriter { +public class NativeIndexWriter { private static final Long CRC32_CHECKSUM_SANITY = 0xFFFFFFFF00000000L; - /** - * Class that holds info about vectors - */ - @Builder - @Value - protected static class NativeVectorInfo { - private VectorDataType vectorDataType; - private int dimension; - } - - /** - * Class that holds info about the native index - */ - @Builder - @Value - protected static class NativeIndexInfo { - private FieldInfo fieldInfo; - private KNNEngine knnEngine; - private int numDocs; - private long arraySize; - private Map parameters; - private NativeVectorInfo vectorInfo; - private String indexPath; - } + private final SegmentWriteState state; + private final FieldInfo fieldInfo; + private final NativeIndexBuildStrategy indexBuilder; + // TODO: Add quantization state as a member variable /** * Gets the correct writer type from fieldInfo @@ -79,42 +66,25 @@ protected static class NativeIndexInfo { * @param fieldInfo * @return correct NativeIndexWriter to make index specified in fieldInfo */ - public static NativeIndexWriter getWriter(FieldInfo fieldInfo) { - final 
KNNEngine knnEngine = getKNNEngine(fieldInfo); - boolean fromScratch = !fieldInfo.attributes().containsKey(MODEL_ID); - boolean iterative = fromScratch && KNNEngine.FAISS == knnEngine; - if (fromScratch && iterative) { - return new NativeIndexWriterScratchIter(); - } else if (fromScratch) { - return new NativeIndexWriterScratch(); - } else { - return new NativeIndexWriterTemplate(); + public static NativeIndexWriter getWriter(final FieldInfo fieldInfo, SegmentWriteState state) { + // TODO: Fetch the quantization state here and pass it to NativeIndexWriter + + final KNNEngine knnEngine = extractKNNEngine(fieldInfo); + boolean isTemplate = fieldInfo.attributes().containsKey(MODEL_ID); + boolean iterative = !isTemplate && KNNEngine.FAISS == knnEngine; + if (iterative) { + return new NativeIndexWriter(state, fieldInfo, MemOptimizedNativeIndexBuildStrategy.getInstance()); } + return new NativeIndexWriter(state, fieldInfo, BulkVectorTransferIndexBuildStrategy.getInstance()); } - /** - * Method for creating a KNN index in the specified native library - * - * @param fieldInfo - * @param valuesProducer - * @param state - * @param isMerge - * @param isRefresh - * @throws IOException - */ - public void createKNNIndex( - FieldInfo fieldInfo, - DocValuesProducer valuesProducer, - SegmentWriteState state, - boolean isMerge, - boolean isRefresh - ) throws IOException { - BinaryDocValues values = valuesProducer.getBinary(fieldInfo); - if (KNNCodecUtil.getTotalLiveDocsCount(values) == 0) { + private boolean writeIndex(final KNNVectorValues knnVectorValues) throws IOException { + if (knnVectorValues.totalLiveDocs() == 0) { log.debug("No live docs for field " + fieldInfo.name); - return; + return false; } - final KNNEngine knnEngine = getKNNEngine(fieldInfo); + + final KNNEngine knnEngine = extractKNNEngine(fieldInfo); final String engineFileName = buildEngineFileName( state.segmentInfo.name, knnEngine.getVersion(), @@ -125,95 +95,136 @@ public void createKNNIndex( ((FSDirectory) 
(FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName ).toString(); - state.directory.createOutput(engineFileName, state.context).close(); - NativeIndexInfo indexInfo = getIndexInfo(fieldInfo, valuesProducer, indexPath); - if (isMerge) { - startMergeStats(indexInfo.numDocs, indexInfo.arraySize); - } - if (isRefresh) { - recordRefreshStats(); - } - createIndex(indexInfo, values); - if (isMerge) { - endMergeStats(indexInfo.numDocs, indexInfo.arraySize); - } + + final BuildIndexParams nativeIndexParams = indexParams(fieldInfo, indexPath, knnEngine); + boolean indexBuilt = indexBuilder.buildIndex(nativeIndexParams, knnVectorValues); writeFooter(indexPath, engineFileName, state); + + return indexBuilt; } - /** - * Method that makes a native index given the parameters from indexInfo - * @param indexInfo - * @param values - * @throws IOException - */ - protected abstract void createIndex(NativeIndexInfo indexInfo, BinaryDocValues values) throws IOException; + public boolean refreshIndex(final KNNVectorValues knnVectorValues) throws IOException { + if (knnVectorValues.docId() == -1) { + knnVectorValues.nextDoc(); + knnVectorValues.getVector(); + } - /** - * Method that generates extra index parameters to be passed to the native library - * @param fieldInfo - * @param knnEngine - * @return extra index parameters to be passed to the native library - * @throws IOException - */ - protected abstract Map getParameters(FieldInfo fieldInfo, KNNEngine knnEngine) throws IOException; + if (knnVectorValues.docId() == NO_MORE_DOCS) { + return false; + } - /** - * Method that gets the native vector info - * @param fieldInfo - * @param valuesProducer - * @return native vector info - * @throws IOException - */ - protected abstract NativeVectorInfo getVectorInfo(FieldInfo fieldInfo, DocValuesProducer valuesProducer) throws IOException; + recordRefreshStats(); + return writeIndex(knnVectorValues); + } + + public boolean mergeIndex(final KNNVectorValues 
knnVectorValues) throws IOException { + if (knnVectorValues.docId() == -1) { + knnVectorValues.nextDoc(); + knnVectorValues.getVector(); + } - protected VectorTransfer getVectorTransfer(VectorDataType vectorDataType) { - if (VectorDataType.BINARY == vectorDataType) { - return new VectorTransferByte(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); + if (knnVectorValues.docId() == NO_MORE_DOCS) { + return false; } - return new VectorTransferFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); + + long arraySize = knnVectorValues.bytesPerVector(); + startMergeStats(knnVectorValues.dimension(), arraySize); + boolean indexWritten = writeIndex(knnVectorValues); + endMergeStats(knnVectorValues.dimension(), arraySize); + + return indexWritten; } - /** - * Method that gets the native index info from a given field - * @param fieldInfo - * @param valuesProducer - * @param indexPath - * @return native index info - * @throws IOException - */ - private NativeIndexInfo getIndexInfo(FieldInfo fieldInfo, DocValuesProducer valuesProducer, String indexPath) throws IOException { - int numDocs = (int) KNNCodecUtil.getTotalLiveDocsCount(valuesProducer.getBinary(fieldInfo)); - NativeVectorInfo vectorInfo = getVectorInfo(fieldInfo, valuesProducer); - KNNEngine knnEngine = getKNNEngine(fieldInfo); - NativeIndexInfo indexInfo = NativeIndexInfo.builder() - .fieldInfo(fieldInfo) - .knnEngine(getKNNEngine(fieldInfo)) - .numDocs((int) numDocs) - .vectorInfo(vectorInfo) - .arraySize(numDocs * getBytesPerVector(vectorInfo)) - .parameters(getParameters(fieldInfo, knnEngine)) + // The logic for building parameters need to be cleaned up. There are various cases handled here + // Currently it falls under two categories - with model and without model. Without model is further divided based on vector data type + // TODO: Refactor this so its scalable. 
Possibly move it out of this class + private BuildIndexParams indexParams(FieldInfo fieldInfo, String indexPath, KNNEngine knnEngine) throws IOException { + final Map parameters; + final VectorDataType vectorDataType; + if (fieldInfo.attributes().containsKey(MODEL_ID)) { + Model model = getModel(fieldInfo); + vectorDataType = model.getModelMetadata().getVectorDataType(); + parameters = getTemplateParameters(fieldInfo, model); + } else { + vectorDataType = VectorDataType.get( + fieldInfo.attributes().getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()) + ); + parameters = getParameters(fieldInfo, vectorDataType, knnEngine); + } + + return BuildIndexParams.builder() + .parameters(parameters) + .vectorDataType(vectorDataType) + .knnEngine(knnEngine) .indexPath(indexPath) .build(); - return indexInfo; } - private long getBytesPerVector(NativeVectorInfo vectorInfo) { - if (vectorInfo.vectorDataType == VectorDataType.BINARY) { - return vectorInfo.dimension / 8; + private Map getParameters(FieldInfo fieldInfo, VectorDataType vectorDataType, KNNEngine knnEngine) throws IOException { + Map parameters = new HashMap<>(); + Map fieldAttributes = fieldInfo.attributes(); + String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); + + // parametersString will be null when legacy mapper is used + if (parametersString == null) { + parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue())); + + String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION); + Map algoParams = new HashMap<>(); + if (efConstruction != null) { + algoParams.put(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, Integer.parseInt(efConstruction)); + } + + String m = fieldAttributes.get(KNNConstants.HNSW_ALGO_M); + if (m != null) { + algoParams.put(KNNConstants.METHOD_PARAMETER_M, Integer.parseInt(m)); + } + parameters.put(PARAMETERS, algoParams); } else { - return 
vectorInfo.dimension * 4; + parameters.putAll( + XContentHelper.createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(parametersString), + MediaTypeRegistry.getDefaultMediaType() + ).map() + ); + } + + parameters.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); + // Update index description of Faiss for binary data type + if (KNNEngine.FAISS == knnEngine + && VectorDataType.BINARY.getValue().equals(vectorDataType.getValue()) + && parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null) { + parameters.put( + KNNConstants.INDEX_DESCRIPTION_PARAMETER, + FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() + ); } + + // Used to determine how many threads to use when indexing + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + + return parameters; + } + + private Map getTemplateParameters(FieldInfo fieldInfo, Model model) throws IOException { + Map parameters = new HashMap<>(); + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + parameters.put(KNNConstants.MODEL_ID, fieldInfo.attributes().get(MODEL_ID)); + parameters.put(KNNConstants.MODEL_BLOB_PARAMETER, model.getModelBlob()); + IndexUtil.updateVectorDataTypeToParameters(parameters, model.getModelMetadata().getVectorDataType()); + return parameters; } - private static KNNEngine getKNNEngine(@NonNull FieldInfo field) { - final String modelId = field.attributes().get(MODEL_ID); - if (modelId != null) { - var model = ModelCache.getInstance().get(modelId); - return model.getModelMetadata().getKnnEngine(); + private Model getModel(FieldInfo fieldInfo) { + String modelId = fieldInfo.attributes().get(MODEL_ID); + Model model = ModelCache.getInstance().get(modelId); + if (model.getModelBlob() == null) { + throw new 
RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); } - final String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName()); - return KNNEngine.getEngine(engineName); + return model; } private void startMergeStats(int numDocs, long arraySize) { diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriterScratch.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriterScratch.java deleted file mode 100644 index 3a410e801d..0000000000 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriterScratch.java +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.nativeindex; - -import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedAction; -import java.util.HashMap; -import java.util.Map; - -import org.apache.lucene.codecs.DocValuesProducer; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.util.BytesRef; -import org.opensearch.common.xcontent.XContentHelper; -import org.opensearch.core.common.bytes.BytesArray; -import org.opensearch.core.xcontent.DeprecationHandler; -import org.opensearch.core.xcontent.MediaTypeRegistry; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.IndexUtil; -import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.codec.transfer.VectorTransfer; -import org.opensearch.knn.index.codec.util.KNNCodecUtil; -import org.opensearch.knn.index.util.KNNEngine; -import org.opensearch.knn.jni.JNIService; - -import lombok.extern.log4j.Log4j2; - -import static 
org.opensearch.knn.common.KNNConstants.PARAMETERS; -import static org.opensearch.knn.index.util.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; - -/** - * Class to build the KNN index from scratch and write it to disk - */ -@Log4j2 -public class NativeIndexWriterScratch extends NativeIndexWriter { - - protected NativeVectorInfo getVectorInfo(FieldInfo fieldInfo, DocValuesProducer valuesProducer) throws IOException { - // Hack to get the data metrics from the first document. We account for this in KNNCodecUtil. - BinaryDocValues testValues = valuesProducer.getBinary(fieldInfo); - testValues.nextDoc(); - BytesRef firstDoc = testValues.binaryValue(); - VectorDataType vectorDataType = VectorDataType.get( - fieldInfo.attributes().getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()) - ); - int dimension = 0; - if (vectorDataType == VectorDataType.BINARY) { - dimension = firstDoc.length * 8; - } else { - dimension = firstDoc.length / 4; - } - NativeVectorInfo vectorInfo = NativeVectorInfo.builder().vectorDataType(vectorDataType).dimension(dimension).build(); - return vectorInfo; - } - - protected Map getParameters(FieldInfo fieldInfo, KNNEngine knnEngine) throws IOException { - Map parameters = new HashMap<>(); - Map fieldAttributes = fieldInfo.attributes(); - String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); - - // parametersString will be null when legacy mapper is used - if (parametersString == null) { - parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue())); - - String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION); - Map algoParams = new HashMap<>(); - if (efConstruction != null) { - algoParams.put(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, Integer.parseInt(efConstruction)); - } - - String m = fieldAttributes.get(KNNConstants.HNSW_ALGO_M); - if (m != null) { - algoParams.put(KNNConstants.METHOD_PARAMETER_M, 
Integer.parseInt(m)); - } - parameters.put(PARAMETERS, algoParams); - } else { - parameters.putAll( - XContentHelper.createParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, - new BytesArray(parametersString), - MediaTypeRegistry.getDefaultMediaType() - ).map() - ); - } - - // Update index description of Faiss for binary data type - if (KNNEngine.FAISS == knnEngine - && VectorDataType.BINARY.getValue() - .equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue())) - && parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null) { - parameters.put( - KNNConstants.INDEX_DESCRIPTION_PARAMETER, - FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() - ); - IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY); - } - // Used to determine how many threads to use when indexing - parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); - return parameters; - } - - protected void createIndex(NativeIndexInfo indexInfo, BinaryDocValues values) throws IOException { - VectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorInfo().getVectorDataType()); - KNNCodecUtil.VectorBatch batch = KNNCodecUtil.getVectorBatch(values, vectorTransfer, false); - AccessController.doPrivileged((PrivilegedAction) () -> { - JNIService.createIndex( - batch.docs, - batch.getVectorAddress(), - batch.getDimension(), - indexInfo.getIndexPath(), - indexInfo.getParameters(), - indexInfo.getKnnEngine() - ); - return null; - }); - } -} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriterScratchIter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriterScratchIter.java deleted file mode 100644 index c3848d7e45..0000000000 --- 
a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriterScratchIter.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.nativeindex; - -import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedAction; -import java.util.Map; - -import org.apache.lucene.index.BinaryDocValues; -import org.opensearch.knn.index.codec.util.KNNCodecUtil; -import org.opensearch.knn.index.util.KNNEngine; -import org.opensearch.knn.jni.JNIService; - -import lombok.extern.log4j.Log4j2; - -/** - * Class to build the KNN index from scratch iteratively and write it to disk - */ -@Log4j2 -public class NativeIndexWriterScratchIter extends NativeIndexWriterScratch { - - @Override - protected void createIndex(NativeIndexInfo indexInfo, BinaryDocValues values) throws IOException { - long indexAddress = initIndexFromScratch( - indexInfo.getNumDocs(), - indexInfo.getVectorInfo().getDimension(), - indexInfo.getKnnEngine(), - indexInfo.getParameters() - ); - while (true) { - KNNCodecUtil.VectorBatch batch = KNNCodecUtil.getVectorBatch( - values, - getVectorTransfer(indexInfo.getVectorInfo().getVectorDataType()), - true - ); - insertToIndex(batch, indexInfo.getKnnEngine(), indexAddress, indexInfo.getParameters()); - if (batch.finished) { - break; - } - } - writeIndex(indexAddress, indexInfo.getIndexPath(), indexInfo.getKnnEngine(), indexInfo.getParameters()); - } - - private long initIndexFromScratch(long size, int dim, KNNEngine knnEngine, Map parameters) throws IOException { - return AccessController.doPrivileged((PrivilegedAction) () -> { - return JNIService.initIndexFromScratch(size, dim, parameters, knnEngine); - }); - } - - private void insertToIndex(KNNCodecUtil.VectorBatch batch, KNNEngine knnEngine, long indexAddress, Map parameters) - throws IOException { - if (batch.docs.length == 0) { - log.debug("Index insertion called 
with a batch without docs."); - return; - } - AccessController.doPrivileged((PrivilegedAction) () -> { - JNIService.insertToIndex(batch.docs, batch.getVectorAddress(), batch.getDimension(), parameters, indexAddress, knnEngine); - return null; - }); - } - - private void writeIndex(long indexAddress, String indexPath, KNNEngine knnEngine, Map parameters) throws IOException { - AccessController.doPrivileged((PrivilegedAction) () -> { - JNIService.writeIndex(indexPath, indexAddress, knnEngine, parameters); - return null; - }); - } -} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriterTemplate.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriterTemplate.java deleted file mode 100644 index f1cb84f979..0000000000 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriterTemplate.java +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.nativeindex; - -import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedAction; -import java.util.HashMap; -import java.util.Map; - -import org.apache.lucene.codecs.DocValuesProducer; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.util.BytesRef; -import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.IndexUtil; -import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.codec.util.KNNCodecUtil; -import org.opensearch.knn.index.util.KNNEngine; -import org.opensearch.knn.indices.Model; -import org.opensearch.knn.indices.ModelCache; -import org.opensearch.knn.jni.JNIService; - -import lombok.extern.log4j.Log4j2; - -import static org.opensearch.knn.common.KNNConstants.MODEL_ID; - -/** - * Abstract class to build the KNN index from a template model 
and write it to disk - */ -@Log4j2 -public class NativeIndexWriterTemplate extends NativeIndexWriter { - - protected void createIndex(NativeIndexInfo indexInfo, BinaryDocValues values) throws IOException { - String modelId = indexInfo.getFieldInfo().attributes().get(MODEL_ID); - Model model = ModelCache.getInstance().get(modelId); - if (model.getModelBlob() == null) { - throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); - } - byte[] modelBlob = model.getModelBlob(); - IndexUtil.updateVectorDataTypeToParameters(indexInfo.getParameters(), model.getModelMetadata().getVectorDataType()); - // This is carried over from the old index creation process. Why can't we get the vector data type - // by just reading it from the field? - KNNCodecUtil.VectorBatch batch = KNNCodecUtil.getVectorBatch( - values, - getVectorTransfer(indexInfo.getVectorInfo().getVectorDataType()), - false - ); - - AccessController.doPrivileged((PrivilegedAction) () -> { - JNIService.createIndexFromTemplate( - batch.docs, - batch.getVectorAddress(), - batch.getDimension(), - indexInfo.getIndexPath(), - modelBlob, - indexInfo.getParameters(), - indexInfo.getKnnEngine() - ); - return null; - }); - } - - @Override - protected Map getParameters(FieldInfo fieldInfo, KNNEngine knnEngine) throws IOException { - Map parameters = new HashMap<>(); - parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); - String modelId = fieldInfo.attributes().get(MODEL_ID); - Model model = ModelCache.getInstance().get(modelId); - if (model.getModelBlob() == null) { - throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); - } - IndexUtil.updateVectorDataTypeToParameters(parameters, model.getModelMetadata().getVectorDataType()); - return parameters; - } - - @Override - protected NativeVectorInfo getVectorInfo(FieldInfo fieldInfo, DocValuesProducer valuesProducer) throws 
IOException { - BinaryDocValues testValues = valuesProducer.getBinary(fieldInfo); - testValues.nextDoc(); - BytesRef firstDoc = testValues.binaryValue(); - String modelId = fieldInfo.attributes().get(MODEL_ID); - Model model = ModelCache.getInstance().get(modelId); - if (model.getModelBlob() == null) { - throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); - } - VectorDataType vectorDataType = model.getModelMetadata().getVectorDataType(); - int dimension = 0; - if (vectorDataType == VectorDataType.BINARY) { - dimension = firstDoc.length * 8; - } else { - dimension = firstDoc.length / 4; - } - NativeVectorInfo vectorInfo = NativeVectorInfo.builder().vectorDataType(vectorDataType).dimension(dimension).build(); - return vectorInfo; - } -} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java new file mode 100644 index 0000000000..d1a0645caa --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex.model; + +import lombok.Builder; +import lombok.Value; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNEngine; + +import java.util.Map; + +@Value +@Builder +public class BuildIndexParams { + KNNEngine knnEngine; + String indexPath; + VectorDataType vectorDataType; + Map parameters; + // TODO: Add quantization state as parameter to build index +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java new file mode 100644 index 0000000000..d14a766369 --- /dev/null +++ 
b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import org.apache.commons.lang.StringUtils; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; + +import java.io.IOException; +import java.util.List; + +/** + * Transfer binary vectors to off heap memory. + */ +public final class OffHeapBinaryVectorTransfer extends OffHeapQuantizedVectorTransfer { + + public OffHeapBinaryVectorTransfer(KNNBinaryVectorValues vectorValues, Long batchSize) throws IOException { + super(vectorValues, batchSize, (vector, state) -> vector, StringUtils.EMPTY, DEFAULT_COMPRESSION_FACTOR); + } + + public OffHeapBinaryVectorTransfer(KNNBinaryVectorValues vectorValues) throws IOException { + this(vectorValues, null); + } + + @Override + public void close() { + super.close(); + // TODO: deallocate the memory location + } + + @Override + protected long transfer(List vectorsToTransfer, long bytesPerVector, boolean append) throws IOException { + // TODO: call to JNIService to transfer vector + return 0L; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java new file mode 100644 index 0000000000..6f98766195 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import org.apache.commons.lang.StringUtils; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.jni.JNICommons; + +import java.io.IOException; +import java.util.List; + +/** + * Transfer byte vectors to off heap memory. 
+ */ +public final class OffHeapByteVectorTransfer extends OffHeapQuantizedVectorTransfer { + + // TODO: Replace with KNNByteVectorValues + public OffHeapByteVectorTransfer(KNNVectorValues vectorValues, final Long batchSize) throws IOException { + super(vectorValues, batchSize, (vector, state) -> vector, StringUtils.EMPTY, DEFAULT_COMPRESSION_FACTOR); + } + + // TODO: Replace with KNNByteVectorValues + public OffHeapByteVectorTransfer(KNNVectorValues vectorValues) throws IOException { + this(vectorValues, null); + } + + @Override + protected long transfer(List batch, long bytesPerVector, boolean append) throws IOException { + return JNICommons.storeByteVectorData(getVectorAddress(), batch.toArray(new byte[][] {}), batchSize * bytesPerVector, append); + } + + @Override + public void close() { + super.close(); + JNICommons.freeByteVectorData(getVectorAddress()); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java new file mode 100644 index 0000000000..182c48b0ec --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import org.apache.commons.lang.StringUtils; +import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; +import org.opensearch.knn.jni.JNICommons; + +import java.io.IOException; +import java.util.List; + +/** + * Transfer float vectors to off heap memory. 
+ */ +public final class OffHeapFloatVectorTransfer extends OffHeapQuantizedVectorTransfer { + + public OffHeapFloatVectorTransfer(KNNFloatVectorValues vectorValues, Long batchSize) throws IOException { + super(vectorValues, batchSize, (vector, state) -> vector, StringUtils.EMPTY, DEFAULT_COMPRESSION_FACTOR); + } + + public OffHeapFloatVectorTransfer(KNNFloatVectorValues vectorValues) throws IOException { + this(vectorValues, null); + } + + @Override + protected long transfer(List vectorsToTransfer, long bytesPerVector, boolean append) throws IOException { + return JNICommons.storeVectorData( + getVectorAddress(), + vectorsToTransfer.toArray(new float[][] {}), + this.batchSize * bytesPerVector, + append + ); + } + + @Override + public void close() { + super.close(); + JNICommons.freeVectorData(getVectorAddress()); + } + +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapQuantizedVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapQuantizedVectorTransfer.java new file mode 100644 index 0000000000..a1a0b0434a --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapQuantizedVectorTransfer.java @@ -0,0 +1,143 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import lombok.Getter; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; + +import java.io.IOException; +import java.util.List; +import java.util.function.BiFunction; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.opensearch.knn.common.KNNVectorUtil.createArrayList; + +/** + * The class is intended to transfer {@link KNNVectorValues} to off heap memory. If also provides and ability to quantize the vector + * before it is transferred to offHeap memory. 
The class is not thread safe + * + * @param an array of primitive type + * @param an array of primitive type after being quantized + */ +abstract class OffHeapQuantizedVectorTransfer implements VectorTransfer { + + protected static final int DEFAULT_COMPRESSION_FACTOR = 1; + + @Getter + private long vectorAddress; + @Getter + private int[] transferredDocsIds; + private final int transferLimit; + // Keeping this as a member variable as this should not be changed considering the vector address is reused between batches + protected long batchSize; + protected final long bytesPerVector; + + private final List vectorsToTransfer; + private final List transferredDocIdsList; + + private final KNNVectorValues vectorValues; + + // TODO: Replace with actual quantization parameters + private final BiFunction quantizer; + private final String quantizationState; + + private final int compressionFactor; + + public OffHeapQuantizedVectorTransfer( + final KNNVectorValues vectorValues, + final Long batchSize, + final BiFunction quantizer, + final String quantizationState, + final int compressionFactor + ) { + assert vectorValues.docId() != -1 : "vectorValues docId must be set, iterate it once for vector transfer to succeed"; + assert vectorValues.docId() != NO_MORE_DOCS : "vectorValues already iterated, Nothing to transfer"; + + this.quantizer = quantizer; + this.quantizationState = quantizationState; + this.bytesPerVector = vectorValues.bytesPerVector() / compressionFactor; + this.compressionFactor = compressionFactor; + this.transferLimit = (int) Math.max(1, (int) KNNSettings.getVectorStreamingMemoryLimit().getBytes() / bytesPerVector); + this.batchSize = batchSize == null ? 
transferLimit : batchSize; + this.vectorsToTransfer = createArrayList(this.batchSize); + this.transferredDocIdsList = createArrayList(this.batchSize); + this.vectorValues = vectorValues; + this.vectorAddress = 0; // we can allocate initial memory here, currently storeVectorData takes care of it + } + + @Override + public void transfer() throws IOException { + if (vectorValues.docId() == NO_MORE_DOCS) { + // Throwing instead of returning so there is no way client can go into an infinite loop + throw new IllegalStateException("No more vectors available to transfer"); + } + + transferredDocsIds = new int[1]; + V vector = quantizer.apply(vectorValues.getVector(), quantizationState); + vectorsToTransfer.add(vector); + + transfer(vectorsToTransfer, vectorValues.bytesPerVector() / compressionFactor, false); + transferredDocsIds[0] = vectorValues.docId(); + vectorsToTransfer.clear(); + vectorValues.nextDoc(); + } + + @Override + public void transferBatch() throws IOException { + if (vectorValues.docId() == NO_MORE_DOCS) { + // Throwing instead of returning so there is no way client can go into an infinite loop + throw new IllegalStateException("No more vectors available to transfer"); + } + + assert vectorsToTransfer.isEmpty() : "Last batch wasn't transferred"; + assert transferredDocIdsList.isEmpty() : "Last batch wasn't transferred"; + + int totalDocsTransferred = 0; + boolean freshBatch = true; + + // Create non-final QuantizationOutput once here and then reuse the output + while (vectorValues.docId() != NO_MORE_DOCS && totalDocsTransferred < batchSize) { + V vector = quantizer.apply(vectorValues.getVector(), quantizationState); + + transferredDocIdsList.add(vectorValues.docId()); + vectorsToTransfer.add(vector); + if (vectorsToTransfer.size() == transferLimit) { + vectorAddress = transfer(vectorsToTransfer, bytesPerVector, !freshBatch); + vectorsToTransfer.clear(); + freshBatch = false; + } + vectorValues.nextDoc(); + totalDocsTransferred++; + } + + // Handle batchSize 
< transferLimit + if (!vectorsToTransfer.isEmpty()) { + vectorAddress = transfer(vectorsToTransfer, bytesPerVector, !freshBatch); + vectorsToTransfer.clear(); + } + + this.transferredDocsIds = new int[transferredDocIdsList.size()]; + for (int i = 0; i < transferredDocIdsList.size(); i++) { + transferredDocsIds[i] = transferredDocIdsList.get(i); + } + transferredDocIdsList.clear(); + } + + @Override + public boolean hasNext() { + return vectorValues.docId() != NO_MORE_DOCS; + } + + @Override + public void close() { + transferredDocIdsList.clear(); + transferredDocsIds = null; + vectorAddress = 0; + } + + protected abstract long transfer(final List vectorsToTransfer, final long bytesPerVector, final boolean append) throws IOException; +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java index 5c847fcc4b..f8847e424e 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java @@ -5,57 +5,43 @@ package org.opensearch.knn.index.codec.transfer; -import lombok.Data; -import org.apache.lucene.util.BytesRef; -import org.opensearch.knn.index.codec.util.SerializationMode; +import java.io.Closeable; +import java.io.IOException; /** - * Abstract class to transfer vector value from Java to native memory + * An interface to transfer vectors from one memory location to another + * Class is Closeable to be able to release memory once done */ -@Data -public abstract class VectorTransfer { - protected final long vectorsStreamingMemoryLimit; - protected long totalLiveDocs; - protected long vectorsPerTransfer; - protected long vectorAddress; - protected int dimension; - - public VectorTransfer(final long vectorsStreamingMemoryLimit) { - this.vectorsStreamingMemoryLimit = vectorsStreamingMemoryLimit; - this.vectorsPerTransfer = Integer.MIN_VALUE; - } +public interface 
VectorTransfer extends Closeable { /** - * Initialize the transfer - * - * @param totalLiveDocs total number of vectors to be transferred + * Transfer a single vector from one location to another + * @throws IOException */ - abstract public void init(final long totalLiveDocs); + void transfer() throws IOException; /** - * Transfer a single vector - * - * @param bytesRef a vector in bytes format + * Transfer a batch of vectors from one location to another + * The batch size here is intended to be constant for multiple transfers so should be encapsulated in the + * implementation. A new batch size should require another instance + * @throws IOException */ - abstract public void transfer(final BytesRef bytesRef); + void transferBatch() throws IOException; /** - * Close the transfer + * Indicates if there are more vectors to transfer + * @return true if there are more vectors to transfer, false otherwise */ - abstract public void close(); + boolean hasNext(); /** - * Get serialization mode of given byte stream - * - * @param bytesRef bytes of a vector - * @return serialization mode + * Gives the docIds for transferred vectors + * @return the docIds of the vectors transferred */ - abstract public SerializationMode getSerializationMode(final BytesRef bytesRef); + int[] getTransferredDocsIds(); /** - * Get number of documents not transferred - * - * @return number of documents not transferred + * @return the memory address of the vectors transferred */ - abstract public int numPendingDocs(); + long getVectorAddress(); } diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java deleted file mode 100644 index cf4066828f..0000000000 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.transfer; - -import org.apache.lucene.util.ArrayUtil; -import
org.apache.lucene.util.BytesRef; -import org.opensearch.knn.index.codec.util.SerializationMode; -import org.opensearch.knn.jni.JNICommons; - -import java.util.ArrayList; -import java.util.List; - -/** - * Vector transfer for byte - */ -public class VectorTransferByte extends VectorTransfer { - private List vectorList; - - public VectorTransferByte(final long vectorsStreamingMemoryLimit) { - super(vectorsStreamingMemoryLimit); - vectorList = new ArrayList<>(); - } - - @Override - public void init(final long totalLiveDocs) { - this.totalLiveDocs = totalLiveDocs; - vectorList.clear(); - } - - @Override - public void transfer(final BytesRef bytesRef) { - dimension = bytesRef.length * 8; - if (vectorsPerTransfer == Integer.MIN_VALUE) { - // if vectorsStreamingMemoryLimit is 100 bytes and we have 50 vectors with length of 5, then per - // transfer we have to send 100/5 => 20 vectors. - vectorsPerTransfer = vectorsStreamingMemoryLimit / bytesRef.length; - // If vectorsPerTransfer comes out to be 0, then we set number of vectors per transfer to 1, to ensure that - // we are sending minimum number of vectors. 
- if (vectorsPerTransfer == 0) { - vectorsPerTransfer = 1; - } - } - - vectorList.add(ArrayUtil.copyOfSubArray(bytesRef.bytes, bytesRef.offset, bytesRef.offset + bytesRef.length)); - if (vectorList.size() == vectorsPerTransfer) { - transfer(); - } - } - - @Override - public void close() { - transfer(); - } - - @Override - public SerializationMode getSerializationMode(final BytesRef bytesRef) { - return SerializationMode.COLLECTIONS_OF_BYTES; - } - - @Override - public int numPendingDocs() { - return vectorList.size(); - } - - private void transfer() { - int lengthOfVector = dimension / 8; - if (totalLiveDocs != 0) { - vectorAddress = JNICommons.storeByteVectorData( - vectorAddress, - vectorList.toArray(new byte[][] {}), - totalLiveDocs * lengthOfVector - ); - } else { - vectorAddress = JNICommons.storeByteVectorData( - vectorAddress, - vectorList.toArray(new byte[][] {}), - vectorList.size() * lengthOfVector - ); - } - vectorList.clear(); - } -} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java deleted file mode 100644 index b4ce95bb1c..0000000000 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.transfer; - -import org.apache.lucene.util.BytesRef; -import org.opensearch.knn.index.codec.util.KNNVectorSerializer; -import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; -import org.opensearch.knn.index.codec.util.SerializationMode; -import org.opensearch.knn.jni.JNICommons; - -import java.util.ArrayList; -import java.util.List; - -/** - * Vector transfer for float - */ -public class VectorTransferFloat extends VectorTransfer { - private List vectorList; - - public VectorTransferFloat(final long vectorsStreamingMemoryLimit) { - 
super(vectorsStreamingMemoryLimit); - vectorList = new ArrayList<>(); - } - - @Override - public void init(final long totalLiveDocs) { - this.totalLiveDocs = totalLiveDocs; - vectorList.clear(); - } - - @Override - public void transfer(final BytesRef bytesRef) { - final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByBytesRef(bytesRef); - final float[] vector = vectorSerializer.byteToFloatArray(bytesRef); - dimension = vector.length; - - if (vectorsPerTransfer == Integer.MIN_VALUE) { - // if vectorsStreamingMemoryLimit is 100 bytes and we have 50 vectors with 5 dimension, then per - // transfer we have to send 100/(5 * 4) => 5 vectors. - vectorsPerTransfer = vectorsStreamingMemoryLimit / ((long) dimension * Float.BYTES); - // If vectorsPerTransfer comes out to be 0, then we set number of vectors per transfer to 1, to ensure that - // we are sending minimum number of vectors. - if (vectorsPerTransfer == 0) { - vectorsPerTransfer = 1; - } - } - - vectorList.add(vector); - if (vectorList.size() == vectorsPerTransfer) { - transfer(); - } - } - - @Override - public void close() { - transfer(); - } - - @Override - public SerializationMode getSerializationMode(final BytesRef bytesRef) { - return KNNVectorSerializerFactory.getSerializerModeFromBytesRef(bytesRef); - } - - @Override - public int numPendingDocs() { - return vectorList.size(); - } - - private void transfer() { - if (totalLiveDocs != 0) { - vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension); - } else { - vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), vectorList.size() * dimension); - } - vectorList.clear(); - } -} diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index e30154c2fc..51100a1e0f 100644 --- 
a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -5,78 +5,14 @@ package org.opensearch.knn.index.codec.util; -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.Setter; import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.util.BytesRef; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.KNN80Codec.KNN80BinaryDocValues; -import org.opensearch.knn.index.codec.transfer.VectorTransfer; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; public class KNNCodecUtil { // Floats are 4 bytes in size public static final int FLOAT_BYTE_SIZE = 4; - @AllArgsConstructor - public static final class VectorBatch { - public int[] docs; - @Getter - @Setter - private long vectorAddress; - @Getter - @Setter - private int dimension; - public boolean finished; - } - - /** - * Extract docIds and vectors from binary doc values. 
- * - * @param values Binary doc values - * @param vectorTransfer Utility to make transfer - * @return KNNCodecUtil.Pair representing doc ids and corresponding vectors - * @throws IOException thrown when unable to get binary of vectors - */ - public static KNNCodecUtil.VectorBatch getVectorBatch( - final BinaryDocValues values, - final VectorTransfer vectorTransfer, - boolean iterative - ) throws IOException { - List docIdList = new ArrayList<>(); - if (iterative) { - // Initializing with a value of zero means to only allocate as much memory on JNI as - // we have inserted for vectors in java side - vectorTransfer.init(0); - } else { - vectorTransfer.init(getTotalLiveDocsCount(values)); - } - for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) { - BytesRef bytesref = values.binaryValue(); - vectorTransfer.transfer(bytesref); - docIdList.add(doc); - // Semi-hacky way to check if the streaming limit has been reached - if (iterative && vectorTransfer.numPendingDocs() == 0) { - break; - } - } - vectorTransfer.close(); - - boolean finished = values.docID() == DocIdSetIterator.NO_MORE_DOCS; - - return new KNNCodecUtil.VectorBatch( - docIdList.stream().mapToInt(Integer::intValue).toArray(), - vectorTransfer.getVectorAddress(), - vectorTransfer.getDimension(), - finished - ); - } - /** * This method provides a rough estimate of the number of bytes used for storing an array with the given parameters. 
* @param numVectors number of vectors in the array diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java index f38099b74a..9066167318 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java @@ -25,17 +25,8 @@ public class KNNBinaryVectorValues extends KNNVectorValues { @Override public byte[] getVector() throws IOException { final byte[] vector = VectorValueExtractorStrategy.extractBinaryVector(vectorValuesIterator); - this.dimension = vector.length; + this.dimension = vector.length * Byte.SIZE; + this.bytesPerVector = vector.length; return vector; } - - /** - * Binary Vector values gets stored as byte[], hence for dimension of the binary vector we have to multiply the - * byte[] size with {@link Byte#SIZE} - * @return int - */ - @Override - public int dimension() { - return super.dimension() * Byte.SIZE; - } } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java index ccbbfab77b..5eca4944d7 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java @@ -26,6 +26,7 @@ public class KNNByteVectorValues extends KNNVectorValues { public byte[] getVector() throws IOException { final byte[] vector = VectorValueExtractorStrategy.extractByteVector(vectorValuesIterator); this.dimension = vector.length; + this.bytesPerVector = vector.length; return vector; } } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java index 174f3a89ee..dfc5384038 100644 --- 
a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java @@ -24,6 +24,7 @@ public class KNNFloatVectorValues extends KNNVectorValues { public float[] getVector() throws IOException { final float[] vector = VectorValueExtractorStrategy.extractFloatVector(vectorValuesIterator); this.dimension = vector.length; + this.bytesPerVector = vector.length * 4L; return vector; } } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java index c4ed64bc25..3e94171e3f 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java @@ -23,6 +23,7 @@ public abstract class KNNVectorValues { protected final KNNVectorValuesIterator vectorValuesIterator; protected int dimension; + protected long bytesPerVector; protected KNNVectorValues(final KNNVectorValuesIterator vectorValuesIterator) { this.vectorValuesIterator = vectorValuesIterator; @@ -46,6 +47,15 @@ public int dimension() { return dimension; } + /** + * Size of a vector in bytes is returned. Do call getVector function first before calling this function otherwise you will get 0 value. + * @return long + */ + public long bytesPerVector() { + assert docId() != -1 && bytesPerVector != 0 : "Cannot get bytesPerVector before we retrieve a vector from KNNVectorValues"; + return bytesPerVector; + } + /** * Returns the total live docs for KNNVectorValues.
* @return long diff --git a/src/main/java/org/opensearch/knn/indices/ModelUtil.java b/src/main/java/org/opensearch/knn/indices/ModelUtil.java index 4c6230a460..02674603ce 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelUtil.java +++ b/src/main/java/org/opensearch/knn/indices/ModelUtil.java @@ -11,9 +11,14 @@ package org.opensearch.knn.indices; +import lombok.experimental.UtilityClass; + +import java.util.Locale; + /** * A utility class for models. */ +@UtilityClass public class ModelUtil { public static void blockCommasInModelDescription(String description) { @@ -33,4 +38,16 @@ public static boolean isModelCreated(ModelMetadata modelMetadata) { return modelMetadata.getState().equals(ModelState.CREATED); } + public static ModelMetadata getModelMetadata(final String modelId) { + if (modelId == null || modelId.isEmpty()) { + return null; + } + Model model = ModelCache.getInstance().get(modelId); + ModelMetadata modelMetadata = model.getModelMetadata(); + if (!ModelUtil.isModelCreated(modelMetadata)) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Model ID '%s' is not created.", modelId)); + } + return modelMetadata; + } + } diff --git a/src/main/java/org/opensearch/knn/jni/JNICommons.java b/src/main/java/org/opensearch/knn/jni/JNICommons.java index 31a8f43cc8..c7222738ef 100644 --- a/src/main/java/org/opensearch/knn/jni/JNICommons.java +++ b/src/main/java/org/opensearch/knn/jni/JNICommons.java @@ -36,16 +36,59 @@ public class JNICommons { * will throw Exception. * *

- * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can - * lead to data corruption. + * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can + * lead to data corruption. *

* - * @param memoryAddress The address of the memory location where data will be stored. - * @param data 2D float array containing data to be stored in native memory. + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D float array containing data to be stored in native memory. * @param initialCapacity The initial capacity of the memory location. * @return memory address where the data is stored. */ - public static native long storeVectorData(long memoryAddress, float[][] data, long initialCapacity); + public static long storeVectorData(long memoryAddress, float[][] data, long initialCapacity) { + return storeVectorData(memoryAddress, data, initialCapacity, true); + } + + /** + * This is utility function that can be used to store data in native memory. This function will allocate memory for + * the data(rows*columns) with initialCapacity and return the memory address where the data is stored. + * If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created. + * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location + * will throw Exception. + * + *

+ * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can + * lead to data corruption. + *

+ * + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D float array containing data to be stored in native memory. + * @param initialCapacity The initial capacity of the memory location. + * @param append append the data or rewrite the memory location + * @return memory address where the data is stored. + */ + public static native long storeVectorData(long memoryAddress, float[][] data, long initialCapacity, boolean append); + + /** + * This is utility function that can be used to store data in native memory. This function will allocate memory for + * the data(rows*columns) with initialCapacity and return the memory address where the data is stored. + * If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created. + * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location + * will throw Exception. + * + *

+ * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can + * lead to data corruption. + *

+ * + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D byte array containing data to be stored in native memory. + * @param initialCapacity The initial capacity of the memory location. + * @return memory address where the data is stored. + */ + public static long storeByteVectorData(long memoryAddress, byte[][] data, long initialCapacity) { + return storeByteVectorData(memoryAddress, data, initialCapacity, true); + } /** * This is utility function that can be used to store data in native memory. This function will allocate memory for @@ -55,24 +98,25 @@ public class JNICommons { * will throw Exception. * *

- * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can - * lead to data corruption. + * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can + * lead to data corruption. *

* - * @param memoryAddress The address of the memory location where data will be stored. - * @param data 2D byte array containing data to be stored in native memory. + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D byte array containing data to be stored in native memory. * @param initialCapacity The initial capacity of the memory location. + * @param append append the data or rewrite the memory location * @return memory address where the data is stored. */ - public static native long storeByteVectorData(long memoryAddress, byte[][] data, long initialCapacity); + public static native long storeByteVectorData(long memoryAddress, byte[][] data, long initialCapacity, boolean append); /** * Free up the memory allocated for the data stored in memory address. This function should be used with the memory - * address returned by {@link JNICommons#storeVectorData(long, float[][], long)} + * address returned by {@link JNICommons#storeVectorData(long, float[][], long, boolean)} * *

- * The function is not threadsafe. If multiple threads are trying to free up same memory location, then it can - * lead to errors. + * The function is not threadsafe. If multiple threads are trying to free up same memory location, then it can + * lead to errors. *

* * @param memoryAddress address to be freed. @@ -81,11 +125,11 @@ public class JNICommons { /** * Free up the memory allocated for the byte data stored in memory address. This function should be used with the memory - * address returned by {@link JNICommons#storeVectorData(long, float[][], long)} + * address returned by {@link JNICommons#storeByteVectorData(long, byte[][], long, boolean)} * *

- * The function is not threadsafe. If multiple threads are trying to free up same memory location, then it can - * lead to errors. + * The function is not threadsafe. If multiple threads are trying to free up same memory location, then it can + * lead to errors. *

* * @param memoryAddress address to be freed. diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index d428d1ee49..7427af6744 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -14,9 +14,9 @@ import org.apache.commons.lang.ArrayUtils; import org.opensearch.common.Nullable; import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.util.IndexUtil; -import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.query.KNNQueryResult; +import org.opensearch.knn.index.util.IndexUtil; import java.util.Locale; import java.util.Map; @@ -29,10 +29,10 @@ public class JNIService { * Initialize an index for the native library. Takes in numDocs to * allocate the correct amount of memory. * - * @param numDocs number of documents to be added - * @param dim dimension of the vector to be indexed + * @param numDocs number of documents to be added + * @param dim dimension of the vector to be indexed * @param parameters parameters to build index - * @param knnEngine knn engine + * @param knnEngine knn engine * @return address of the index in memory */ public static long initIndexFromScratch(long numDocs, int dim, Map parameters, KNNEngine knnEngine) { @@ -52,12 +52,12 @@ public static long initIndexFromScratch(long numDocs, int dim, Map parameters) { if (KNNEngine.FAISS == knnEngine) { @@ -111,12 +111,12 @@ public static void writeIndex(String indexPath, long indexAddress, KNNEngine knn * created the memory address and that should only free up the memory. 
We are tracking the proper fix for this on this * issue * - * @param ids array of ids mapping to the data passed in + * @param ids array of ids mapping to the data passed in * @param vectorsAddress address of native memory where vectors are stored - * @param dim dimension of the vector to be indexed - * @param indexPath path to save index file to - * @param parameters parameters to build index - * @param knnEngine engine to build index for + * @param dim dimension of the vector to be indexed + * @param indexPath path to save index file to + * @param parameters parameters to build index + * @param knnEngine engine to build index for */ public static void createIndex( int[] ids, @@ -140,13 +140,13 @@ public static void createIndex( /** * Create an index for the native library with a provided template index * - * @param ids array of ids mapping to the data passed in + * @param ids array of ids mapping to the data passed in * @param vectorsAddress address of native memory where vectors are stored - * @param dim dimension of vectors to be indexed - * @param indexPath path to save index file to - * @param templateIndex empty template index - * @param parameters parameters to build index - * @param knnEngine engine to build index for + * @param dim dimension of vectors to be indexed + * @param indexPath path to save index file to + * @param templateIndex empty template index + * @param parameters parameters to build index + * @param knnEngine engine to build index for */ public static void createIndexFromTemplate( int[] ids, @@ -252,13 +252,13 @@ public static void setSharedIndexState(long indexAddr, long shareIndexStateAddr, /** * Query an index * - * @param indexPointer pointer to index in memory - * @param queryVector vector to be used for query - * @param k neighbors to be returned - * @param methodParameters method parameter - * @param knnEngine engine to query index - * @param filteredIds array of ints on which should be used for search. 
- * @param filterIdsType how to filter ids: Batch or BitMap + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param k neighbors to be returned + * @param methodParameters method parameter + * @param knnEngine engine to query index + * @param filteredIds array of ints on which should be used for search. + * @param filterIdsType how to filter ids: Batch or BitMap * @return KNNQueryResult array of k neighbors */ public static KNNQueryResult[] queryIndex( @@ -301,13 +301,13 @@ public static KNNQueryResult[] queryIndex( /** * Query a binary index * - * @param indexPointer pointer to index in memory - * @param queryVector vector to be used for query - * @param k neighbors to be returned - * @param methodParameters method parameter - * @param knnEngine engine to query index - * @param filteredIds array of ints on which should be used for search. - * @param filterIdsType how to filter ids: Batch or BitMap + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param k neighbors to be returned + * @param methodParameters method parameter + * @param knnEngine engine to query index + * @param filteredIds array of ints on which should be used for search. + * @param filterIdsType how to filter ids: Batch or BitMap * @return KNNQueryResult array of k neighbors */ public static KNNQueryResult[] queryBinaryIndex( @@ -407,12 +407,12 @@ public static byte[] trainIndex(Map indexParameters, int dimensi /** *

- * The function is deprecated. Use {@link JNICommons#storeVectorData(long, float[][], long)} + * The function is deprecated. Use {@link JNICommons#storeVectorData(long, float[][], long, boolean)} *

* Transfer vectors from Java to native * * @param vectorsPointer pointer to vectors in native memory. Should be 0 to create vector as well - * @param trainingData data to be transferred + * @param trainingData data to be transferred * @return pointer to native memory location of training data */ @Deprecated(since = "2.14.0", forRemoval = true) @@ -423,15 +423,15 @@ public static long transferVectors(long vectorsPointer, float[][] trainingData) /** * Range search index for a given query vector * - * @param indexPointer pointer to index in memory - * @param queryVector vector to be used for query - * @param radius search within radius threshold - * @param methodParameters parameters to be used when loading index - * @param knnEngine engine to query index + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param radius search within radius threshold + * @param methodParameters parameters to be used when loading index + * @param knnEngine engine to query index * @param indexMaxResultWindow maximum number of results to return - * @param filteredIds list of doc ids to include in the query result - * @param filterIdsType how to filter ids: Batch or BitMap - * @param parentIds parent ids of the vectors + * @param filteredIds list of doc ids to include in the query result + * @param filterIdsType how to filter ids: Batch or BitMap + * @param parentIds parent ids of the vectors * @return KNNQueryResult array of neighbors within radius */ public static KNNQueryResult[] radiusQueryIndex( diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index e87531561a..277211ae6f 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -104,7 +104,7 @@ 
public void testAddBinaryField_withKNN() throws IOException { KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, null) { @Override - public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh) { + public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge) { called[0] = true; } }; @@ -141,7 +141,7 @@ public void testAddBinaryField_withoutKNN() throws IOException { KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, state) { @Override - public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh) { + public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge) { called[0] = true; } }; @@ -159,7 +159,6 @@ public void testAddKNNBinaryField_noVectors() throws IOException { 128 ); Long initialGraphIndexRequests = KNNCounter.GRAPH_INDEX_REQUESTS.getCount(); - Long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); Long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); Long initialMergeSize = KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue(); Long initialMergeDocs = KNNGraphValue.MERGE_TOTAL_DOCS.getValue(); @@ -177,9 +176,8 @@ public void testAddKNNBinaryField_noVectors() throws IOException { SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field").build(); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfo, randomVectorDocValuesProducer, true, true); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfo, randomVectorDocValuesProducer, true); assertEquals(initialGraphIndexRequests, KNNCounter.GRAPH_INDEX_REQUESTS.getCount()); - 
assertEquals(initialRefreshOperations, KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); assertEquals(initialMergeOperations, KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); assertEquals(initialMergeSize, KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); assertEquals(initialMergeDocs, KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); @@ -223,7 +221,6 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); - long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); // Add documents to the field @@ -232,7 +229,61 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException docsInSegment, dimension ); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true); + + // The document should be created in the correct location + String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); + assertFileInCorrectLocation(state, expectedFile); + + // The footer should be valid + assertValidFooter(state.directory, expectedFile); + + // The document should be readable by nmslib + assertLoadableByEngine(null, state, expectedFile, knnEngine, spaceType, dimension); + + // The graph creation statistics should be updated + assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); + assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); + assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); + } + + public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws 
IOException { + // Set information about the segment and the fields + String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); + int docsInSegment = 100; + String fieldName = String.format("test_field%s", randomAlphaOfLength(4)); + + KNNEngine knnEngine = KNNEngine.NMSLIB; + SpaceType spaceType = SpaceType.COSINESIMIL; + int dimension = 16; + + SegmentInfo segmentInfo = KNNCodecTestUtil.segmentInfoBuilder() + .directory(directory) + .segmentName(segmentName) + .docsInSegment(docsInSegment) + .codec(codec) + .build(); + + FieldInfo[] fieldInfoArray = new FieldInfo[] { + KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .addAttribute(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, "512") + .addAttribute(KNNConstants.HNSW_ALGO_M, "16") + .addAttribute(KNNConstants.SPACE_TYPE, spaceType.getValue()) + .build() }; + + FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + + long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); + + // Add documents to the field + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); + TestVectorValues.RandomVectorDocValuesProducer randomVectorDocValuesProducer = new TestVectorValues.RandomVectorDocValuesProducer( + docsInSegment, + dimension + ); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true); // The document should be created in the correct location String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); @@ -245,7 +296,6 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException assertLoadableByEngine(null, state, expectedFile, knnEngine, spaceType, dimension); // The graph creation statistics should be updated - assertEquals(1 + 
initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); @@ -290,7 +340,6 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); - long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); // Add documents to the field KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); @@ -298,7 +347,7 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException docsInSegment, dimension ); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, false); // The document should be created in the correct location String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); @@ -312,9 +361,6 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException // The graph creation statistics should be updated assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); - assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); - assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); - assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); } public void testAddKNNBinaryField_whenFaissBinary_thenAdded() throws IOException { @@ -357,7 +403,6 @@ public void 
testAddKNNBinaryField_whenFaissBinary_thenAdded() throws IOException FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); - long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); // Add documents to the field @@ -366,7 +411,7 @@ public void testAddKNNBinaryField_whenFaissBinary_thenAdded() throws IOException docsInSegment, dimension ); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true); // The document should be created in the correct location String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); @@ -379,7 +424,6 @@ public void testAddKNNBinaryField_whenFaissBinary_thenAdded() throws IOException assertBinaryIndexLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension, dataType); // The graph creation statistics should be updated - assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); @@ -456,7 +500,6 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); - long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); // 
Add documents to the field @@ -465,7 +508,7 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio docsInSegment, dimension ); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true); // The document should be created in the correct location String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); @@ -478,7 +521,6 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio assertLoadableByEngine(HNSW_METHODPARAMETERS, state, expectedFile, knnEngine, spaceType, dimension); // The graph creation statistics should be updated - assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); @@ -550,6 +592,6 @@ public void testAddBinaryField_luceneEngine_noInvocations_addKNNBinary() throws knn80DocValuesConsumer.addBinaryField(fieldInfo, docValuesProducer); verify(delegate, times(1)).addBinaryField(fieldInfo, docValuesProducer); - verify(knn80DocValuesConsumer, never()).addKNNBinaryField(any(), any(), eq(false), eq(true)); + verify(knn80DocValuesConsumer, never()).addKNNBinaryField(any(), any(), eq(false)); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java index 322b714f2d..feb6b93747 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java +++ 
b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java @@ -46,6 +46,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.engine.KNNEngine; @@ -99,10 +100,13 @@ public void testNativeEngineVectorFormat_whenMultipleVectorFieldIndexed_thenSucc byte[] byteVector = { 6, 14 }; addFieldToIndex( - new KnnFloatVectorField(FLOAT_VECTOR_FIELD, floatVector, createVectorField(3, VectorEncoding.FLOAT32)), + new KnnFloatVectorField(FLOAT_VECTOR_FIELD, floatVector, createVectorField(3, VectorEncoding.FLOAT32, VectorDataType.FLOAT)), + indexWriter + ); + addFieldToIndex( + new KnnByteVectorField(BYTE_VECTOR_FIELD, byteVector, createVectorField(2, VectorEncoding.BYTE, VectorDataType.BYTE)), indexWriter ); - addFieldToIndex(new KnnByteVectorField(BYTE_VECTOR_FIELD, byteVector, createVectorField(2, VectorEncoding.BYTE)), indexWriter); final IndexReader indexReader = indexWriter.getReader(); // ensuring segments are created indexWriter.flush(); @@ -187,17 +191,19 @@ private void addFieldToIndex(final Field vectorField, final RandomIndexWriter in indexWriter.addDocument(doc1); } - private FieldType createVectorField(int dimension, VectorEncoding vectorEncoding) { + private FieldType createVectorField(int dimension, VectorEncoding vectorEncoding, VectorDataType vectorDataType) { FieldType nativeVectorField = new FieldType(); // TODO: Replace this with the default field which will be created in mapper for Native Engines with KNNVectorsFormat nativeVectorField.setTokenized(false); nativeVectorField.setIndexOptions(IndexOptions.NONE); nativeVectorField.putAttribute(KNNVectorFieldMapper.KNN_FIELD, "true"); nativeVectorField.putAttribute(KNNConstants.KNN_METHOD, KNNConstants.METHOD_HNSW); - 
nativeVectorField.putAttribute(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()); + nativeVectorField.putAttribute(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()); nativeVectorField.putAttribute(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()); nativeVectorField.putAttribute(KNNConstants.HNSW_ALGO_M, "32"); nativeVectorField.putAttribute(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, "512"); + nativeVectorField.putAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); + nativeVectorField.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"HNSW16,Flat\", \"spaceType\": \"l2\"}"); nativeVectorField.setVectorAttributes( dimension, vectorEncoding, diff --git a/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java new file mode 100644 index 0000000000..780d164ae3 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java @@ -0,0 +1,162 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import lombok.SneakyThrows; +import org.apache.lucene.index.DocsWithFieldSet; +import org.junit.Before; +import org.mockito.Mock; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.core.common.unit.ByteSizeValue; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; + +import java.util.Map; + +import static org.mockito.Mockito.when; +import static org.opensearch.knn.index.KNNSettings.KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING; + +public class OffHeapVectorTransferTests extends KNNTestCase { + 
+ @Mock + ClusterSettings clusterSettings; + + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + KNNSettings.state().setClusterService(clusterService); + } + + @SneakyThrows + public void testFloatTransfer() { + // Given + when(clusterSettings.get(KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING)).thenReturn(new ByteSizeValue(16)); + final Map docs = Map.of(0, new float[] { 1, 2 }, 1, new float[] { 2, 3 }, 2, new float[] { 3, 4 }); + DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); + docs.keySet().stream().sorted().forEach(docsWithFieldSet::add); + + //Transfer 1 vector + int[] expectedDocIds = new int[] { 0, 1, 2 }; + KNNFloatVectorValues knnVectorValues = (KNNFloatVectorValues) KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, docs); + knnVectorValues.nextDoc(); knnVectorValues.getVector(); + VectorTransfer vectorTransfer = new OffHeapFloatVectorTransfer(knnVectorValues); + testTransferSingleVector(vectorTransfer, expectedDocIds); + + //Transfer batch, limit == batch size + knnVectorValues = (KNNFloatVectorValues) KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, docs); + knnVectorValues.nextDoc(); knnVectorValues.getVector(); + vectorTransfer = new OffHeapFloatVectorTransfer(knnVectorValues); + testTransferBatchVectors(vectorTransfer, new int[][] { { 0, 1 }, { 2 } }, 2); + + //Transfer batch, limit < batch size + knnVectorValues = (KNNFloatVectorValues) KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, docs); + knnVectorValues.nextDoc(); knnVectorValues.getVector(); + vectorTransfer = new OffHeapFloatVectorTransfer(knnVectorValues, 5L); + vectorTransfer.transferBatch(); + assertNotEquals(0, vectorTransfer.getVectorAddress()); + assertArrayEquals(new int[] {0, 1, 2}, vectorTransfer.getTransferredDocsIds()); + + //Transfer batch, limit > batch size + knnVectorValues 
= (KNNFloatVectorValues) KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, docs); + knnVectorValues.nextDoc(); knnVectorValues.getVector(); + vectorTransfer = new OffHeapFloatVectorTransfer(knnVectorValues, 1L); + testTransferBatchVectors(vectorTransfer, new int[][] { { 0 }, { 1 }, { 2 } }, 3); + } + + @SneakyThrows + public void testByteTransfer() { + // Given + when(clusterSettings.get(KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING)).thenReturn(new ByteSizeValue(4)); + final Map docs = Map.of(0, new byte[] { 1, 2 }, 1, new byte[] { 2, 3 }, 2, new byte[] { 3, 4 }); + DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); + docs.keySet().stream().sorted().forEach(docsWithFieldSet::add); + + //Transfer 1 vector + int[] expectedDocIds = new int[] { 0, 1, 2 }; + KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.BYTE, docsWithFieldSet, docs); + knnVectorValues.nextDoc(); knnVectorValues.getVector(); + VectorTransfer vectorTransfer = new OffHeapByteVectorTransfer(knnVectorValues); + testTransferSingleVector(vectorTransfer, expectedDocIds); + + //Transfer batch, limit == batch size + knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.BYTE, docsWithFieldSet, docs); + knnVectorValues.nextDoc(); knnVectorValues.getVector(); + vectorTransfer = new OffHeapByteVectorTransfer(knnVectorValues); + testTransferBatchVectors(vectorTransfer, new int[][] { { 0, 1 }, { 2 } }, 2); + + //Transfer batch, limit < batch size + knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.BYTE, docsWithFieldSet, docs); + knnVectorValues.nextDoc(); knnVectorValues.getVector(); + vectorTransfer = new OffHeapByteVectorTransfer(knnVectorValues, 5L); + vectorTransfer.transferBatch(); + assertNotEquals(0, vectorTransfer.getVectorAddress()); + assertArrayEquals(new int[] {0, 1, 2}, vectorTransfer.getTransferredDocsIds()); + + //Transfer batch, limit > batch size + knnVectorValues = 
KNNVectorValuesFactory.getVectorValues(VectorDataType.BYTE, docsWithFieldSet, docs); + knnVectorValues.nextDoc(); knnVectorValues.getVector(); + vectorTransfer = new OffHeapByteVectorTransfer(knnVectorValues, 1L); + testTransferBatchVectors(vectorTransfer, new int[][] { { 0 }, { 1 }, { 2 } }, 3); + } + + // TODO: Add a unit test for binary + + @SneakyThrows + private void testTransferSingleVector(VectorTransfer vectorTransfer, int[] expectedDocIds) { + long vectorAddress = 0L; + try { + int iteration = 0; + while (vectorTransfer.hasNext()) { + vectorTransfer.transfer(); + if (iteration != 0) { + assertEquals("Vector address shouldn't be different", vectorAddress, vectorTransfer.getVectorAddress()); + } else { + assertEquals(0L, vectorAddress); + vectorAddress = vectorTransfer.getVectorAddress(); + } + assertEquals(expectedDocIds[iteration], vectorTransfer.getTransferredDocsIds()[0]); + iteration++; + } + assertEquals(iteration, expectedDocIds.length); + } finally { + vectorTransfer.close(); + assertEquals(vectorTransfer.getVectorAddress(), 0); + assertNull(vectorTransfer.getTransferredDocsIds()); + } + } + + @SneakyThrows + private void testTransferBatchVectors(VectorTransfer vectorTransfer, int[][] expectedDocIds, int expectedIterations) { + long vectorAddress = 0L; + try { + int iteration = 0; + while (vectorTransfer.hasNext()) { + vectorTransfer.transferBatch(); + if (iteration != 0) { + assertEquals("Vector address shouldn't be different", vectorAddress, vectorTransfer.getVectorAddress()); + } else { + assertEquals(0, vectorAddress); + vectorAddress = vectorTransfer.getVectorAddress(); + } + assertArrayEquals(expectedDocIds[iteration], vectorTransfer.getTransferredDocsIds()); + iteration++; + } + assertEquals(expectedIterations, iteration); + } finally { + vectorTransfer.close(); + assertEquals(vectorTransfer.getVectorAddress(), 0); + assertNull(vectorTransfer.getTransferredDocsIds()); + } + } +} diff --git 
a/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferByteTests.java b/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferByteTests.java deleted file mode 100644 index 2f091a0355..0000000000 --- a/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferByteTests.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.transfer; - -import junit.framework.TestCase; -import lombok.SneakyThrows; -import org.apache.lucene.util.BytesRef; -import org.opensearch.knn.index.codec.util.SerializationMode; -import org.opensearch.knn.jni.JNICommons; - -import java.io.IOException; -import java.util.Random; - -import static org.junit.Assert.assertNotEquals; - -public class VectorTransferByteTests extends TestCase { - @SneakyThrows - public void testTransfer_whenCalled_thenAdded() { - final BytesRef bytesRef1 = getByteArrayOfVectors(20); - final BytesRef bytesRef2 = getByteArrayOfVectors(20); - VectorTransferByte vectorTransfer = new VectorTransferByte(40); - try { - vectorTransfer.init(2); - - vectorTransfer.transfer(bytesRef1); - // flush is not called - assertEquals(0, vectorTransfer.getVectorAddress()); - - vectorTransfer.transfer(bytesRef2); - // flush should be called - assertNotEquals(0, vectorTransfer.getVectorAddress()); - } finally { - if (vectorTransfer.getVectorAddress() != 0) { - JNICommons.freeVectorData(vectorTransfer.getVectorAddress()); - } - } - } - - @SneakyThrows - public void testSerializationMode_whenCalled_thenReturn() { - final BytesRef bytesRef = getByteArrayOfVectors(20); - VectorTransferByte vectorTransfer = new VectorTransferByte(1000); - - // Verify - assertEquals(SerializationMode.COLLECTIONS_OF_BYTES, vectorTransfer.getSerializationMode(bytesRef)); - } - - private BytesRef getByteArrayOfVectors(int vectorLength) throws IOException { - byte[] vector = new byte[vectorLength]; - new 
Random().nextBytes(vector); - return new BytesRef(vector); - } -} diff --git a/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloatTests.java b/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloatTests.java deleted file mode 100644 index 620fd7c65f..0000000000 --- a/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloatTests.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.transfer; - -import junit.framework.TestCase; -import lombok.SneakyThrows; -import org.apache.lucene.util.BytesRef; -import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; -import org.opensearch.knn.jni.JNICommons; - -import java.io.ByteArrayOutputStream; -import java.io.DataOutputStream; -import java.io.IOException; -import java.util.Random; -import java.util.stream.IntStream; - -import static org.junit.Assert.assertNotEquals; - -public class VectorTransferFloatTests extends TestCase { - @SneakyThrows - public void testTransfer_whenCalled_thenAdded() { - final BytesRef bytesRef1 = getByteArrayOfVectors(20); - final BytesRef bytesRef2 = getByteArrayOfVectors(20); - VectorTransferFloat vectorTransfer = new VectorTransferFloat(160); - try { - vectorTransfer.init(2); - - vectorTransfer.transfer(bytesRef1); - // flush is not called - assertEquals(0, vectorTransfer.getVectorAddress()); - - vectorTransfer.transfer(bytesRef2); - // flush should be called - assertNotEquals(0, vectorTransfer.getVectorAddress()); - } finally { - if (vectorTransfer.getVectorAddress() != 0) { - JNICommons.freeVectorData(vectorTransfer.getVectorAddress()); - } - } - } - - @SneakyThrows - public void testSerializationMode_whenCalled_thenReturn() { - final BytesRef bytesRef = getByteArrayOfVectors(20); - VectorTransferFloat vectorTransfer = new VectorTransferFloat(1000); - - // Verify - 
assertEquals(KNNVectorSerializerFactory.getSerializerModeFromBytesRef(bytesRef), vectorTransfer.getSerializationMode(bytesRef)); - } - - private BytesRef getByteArrayOfVectors(int vectorLength) throws IOException { - float[] vector = new float[vectorLength]; - IntStream.range(0, vectorLength).forEach(index -> vector[index] = new Random().nextFloat()); - - final ByteArrayOutputStream bas = new ByteArrayOutputStream(); - final DataOutputStream ds = new DataOutputStream(bas); - for (float f : vector) { - ds.writeFloat(f); - } - final byte[] vectorAsCollectionOfFloats = bas.toByteArray(); - return new BytesRef(vectorAsCollectionOfFloats); - } -} diff --git a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java index 5a68d96d4d..dbea6375b2 100644 --- a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java @@ -6,51 +6,11 @@ package org.opensearch.knn.index.codec.util; import junit.framework.TestCase; -import lombok.SneakyThrows; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.util.BytesRef; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.codec.transfer.VectorTransfer; -import java.util.Arrays; - -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize; public class KNNCodecUtilTests extends TestCase { - @SneakyThrows - public void testGetPair_whenCalled_thenReturn() { - long liveDocCount = 1l; - int[] docId = { 2 }; - long vectorAddress = 3l; - int dimension = 4; - BytesRef bytesRef = new BytesRef(); - - BinaryDocValues binaryDocValues = 
mock(BinaryDocValues.class); - when(binaryDocValues.cost()).thenReturn(liveDocCount); - when(binaryDocValues.nextDoc()).thenReturn(docId[0], NO_MORE_DOCS); - when(binaryDocValues.binaryValue()).thenReturn(bytesRef); - - VectorTransfer vectorTransfer = mock(VectorTransfer.class); - when(vectorTransfer.getVectorAddress()).thenReturn(vectorAddress); - when(vectorTransfer.getDimension()).thenReturn(dimension); - - // Run - KNNCodecUtil.VectorBatch batch = KNNCodecUtil.getVectorBatch(binaryDocValues, vectorTransfer, false); - - // Verify - verify(vectorTransfer).init(liveDocCount); - verify(vectorTransfer).transfer(any(BytesRef.class)); - verify(vectorTransfer).close(); - - assertTrue(Arrays.equals(docId, batch.docs)); - assertEquals(vectorAddress, batch.getVectorAddress()); - assertEquals(dimension, batch.getDimension()); - } public void testCalculateArraySize() { int numVectors = 4; diff --git a/src/testFixtures/java/org/opensearch/knn/TestUtils.java b/src/testFixtures/java/org/opensearch/knn/TestUtils.java index e2b831e6ee..741717116a 100644 --- a/src/testFixtures/java/org/opensearch/knn/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/knn/TestUtils.java @@ -19,7 +19,7 @@ import java.io.IOException; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.util.SerializationMode; -import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.plugin.script.KNNScoringUtil; @@ -399,11 +399,11 @@ private void initBinaryData() { } public long loadDataToMemoryAddress() { - return JNICommons.storeVectorData(0, indexData.vectors, (long) indexData.vectors.length * indexData.vectors[0].length); + return JNICommons.storeVectorData(0, indexData.vectors, (long) indexData.vectors.length * indexData.vectors[0].length, true); } public long loadBinaryDataToMemoryAddress() { - return 
JNICommons.storeByteVectorData(0, indexBinaryData, (long) indexBinaryData.length * indexBinaryData[0].length); + return JNICommons.storeByteVectorData(0, indexBinaryData, (long) indexBinaryData.length * indexBinaryData[0].length, true); } @AllArgsConstructor