Skip to content

Commit

Permalink
Integrates FAISS iterative builds with NativeEngines990KnnVectorsFormat
Browse files Browse the repository at this point in the history
Changes include reusing the same vector buffer in the JNI layer

Signed-off-by: Tejas Shah <[email protected]>
  • Loading branch information
shatejas committed Aug 12, 2024
1 parent e5823fb commit af635b3
Show file tree
Hide file tree
Showing 42 changed files with 1,355 additions and 1,040 deletions.
17 changes: 15 additions & 2 deletions jni/include/commons.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,19 @@ namespace knn_jni {
* For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location
* will throw Exception.
*
* append tells the method to keep appending to the existing vector. Passing the value as false will clear the vector
* without reallocating new memory. This helps with reducing memory frangmentation and overhead of allocating
* and deallocating when the memory address needs to be reused.
*
* CAUTION: The behavior is undefined if the memory address is deallocated and the method is called
*
* @param memoryAddress The address of the memory location where data will be stored.
* @param data 2D float array containing data to be stored in native memory.
* @param initialCapacity The initial capacity of the memory location.
* @param append whether to append or start from index 0 when called subsequently with the same address
* @return memory address of std::vector<float> where the data is stored.
*/
jlong storeVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong);
jlong storeVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean);

/**
* This is utility function that can be used to store data in native memory. This function will allocate memory for
Expand All @@ -33,12 +40,18 @@ namespace knn_jni {
* For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location
* will throw Exception.
*
* append tells the method to keep appending to the existing vector. Passing the value as false will clear the vector
* without reallocating new memory. This helps with reducing memory frangmentation and overhead of allocating
* and deallocating when the memory address needs to be reused.
*
* CAUTION: The behavior is undefined if the memory address is deallocated and the method is called
*
* @param memoryAddress The address of the memory location where data will be stored.
* @param data 2D byte array containing data to be stored in native memory.
* @param initialCapacity The initial capacity of the memory location.
* @return memory address of std::vector<uint8_t> where the data is stored.
*/
jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong);
jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean);

/**
* Free up the memory allocated for the data stored in memory address. This function should be used with the memory
Expand Down
6 changes: 3 additions & 3 deletions jni/include/org_opensearch_knn_jni_JNICommons.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@ extern "C" {
/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: storeVectorData
* Signature: (J[[FJJ)
* Signature: (J[[FJJJ)
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData
(JNIEnv *, jclass, jlong, jobjectArray, jlong);
(JNIEnv *, jclass, jlong, jobjectArray, jlong, jboolean);

/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: storeVectorData
* Signature: (J[[FJJ)
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData
(JNIEnv *, jclass, jlong, jobjectArray, jlong);
(JNIEnv *, jclass, jlong, jobjectArray, jlong, jboolean);

/*
* Class: org_opensearch_knn_jni_JNICommons
Expand Down
14 changes: 12 additions & 2 deletions jni/src/commons.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,39 @@
#include "commons.h"

jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ,
jobjectArray dataJ, jlong initialCapacityJ) {
jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) {
std::vector<float> *vect;
if ((long) memoryAddressJ == 0) {
vect = new std::vector<float>();
vect->reserve((long)initialCapacityJ);
} else {
vect = reinterpret_cast<std::vector<float>*>(memoryAddressJ);
}

if (appendJ == JNI_FALSE) {
vect->clear();
}

int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, dataJ);
jniUtil->Convert2dJavaObjectArrayAndStoreToFloatVector(env, dataJ, dim, vect);

return (jlong) vect;
}

jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ,
jobjectArray dataJ, jlong initialCapacityJ) {
jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) {
std::vector<uint8_t> *vect;
if ((long) memoryAddressJ == 0) {
vect = new std::vector<uint8_t>();
vect->reserve((long)initialCapacityJ);
} else {
vect = reinterpret_cast<std::vector<uint8_t>*>(memoryAddressJ);
}

if (appendJ == JNI_FALSE) {
vect->clear();
}

int dim = jniUtil->GetInnerDimensionOf2dJavaByteArray(env, dataJ);
jniUtil->Convert2dJavaObjectArrayAndStoreToByteVector(env, dataJ, dim, vect);

Expand Down
1 change: 0 additions & 1 deletion jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JN
std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods));
knn_jni::faiss_wrapper::InsertToIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexAddress, threadCount, &indexService);
delete reinterpret_cast<std::vector<float>*>(vectorsAddressJ);
} catch (...) {
// NOTE: ADDING DELETE STATEMENT HERE CAUSES A CRASH!
jniUtil.CatchCppExceptionAndThrowJava(env);
Expand Down
8 changes: 4 additions & 4 deletions jni/src/org_opensearch_knn_jni_JNICommons.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,23 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) {


JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData(JNIEnv * env, jclass cls,
jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ)
jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ)

{
try {
return knn_jni::commons::storeVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ);
return knn_jni::commons::storeVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ, appendJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
return (long)memoryAddressJ;
}

JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData(JNIEnv * env, jclass cls,
jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ)
jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ)

{
try {
return knn_jni::commons::storeByteVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ);
return knn_jni::commons::storeByteVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ, appendJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand Down
114 changes: 111 additions & 3 deletions jni/tests/commons_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ TEST(CommonsTests, BasicAssertions) {
testing::NiceMock<test_util::MockJNIUtil> mockJNIUtil;

jlong memoryAddress = knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, (jlong)0,
reinterpret_cast<jobjectArray>(&data), (jlong)(totalNumberOfVector * dim));
reinterpret_cast<jobjectArray>(&data), (jlong)(totalNumberOfVector * dim), true);
ASSERT_NE(memoryAddress, 0);
auto *vect = reinterpret_cast<std::vector<float>*>(memoryAddress);
ASSERT_EQ(vect->size(), data.size() * dim);
Expand All @@ -48,12 +48,13 @@ TEST(CommonsTests, BasicAssertions) {
}
data2.push_back(vector);
memoryAddress = knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, memoryAddress,
reinterpret_cast<jobjectArray>(&data2), (jlong)(totalNumberOfVector * dim));
reinterpret_cast<jobjectArray>(&data2), (jlong)(totalNumberOfVector * dim), true);
ASSERT_NE(memoryAddress, 0);
ASSERT_EQ(memoryAddress, oldMemoryAddress);
vect = reinterpret_cast<std::vector<float>*>(memoryAddress);
int currentIndex = 0;
ASSERT_EQ(vect->size(), totalNumberOfVector*dim);
std::cout << vect->size() + "\n";
ASSERT_EQ(vect->size(), totalNumberOfVector * dim);
ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim);

// Validate if all vectors data are at correct location
Expand All @@ -70,6 +71,113 @@ TEST(CommonsTests, BasicAssertions) {
currentIndex++;
}
}

// test append == true
std::vector<std::vector<float>> data3;
std::vector<float> vecto3;
for(int j = 0 ; j < dim ; j ++) {
vecto3.push_back((float)j);
}
data3.push_back(vecto3);
memoryAddress = knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, memoryAddress,
reinterpret_cast<jobjectArray>(&data3), (jlong)(totalNumberOfVector * dim), false);
ASSERT_NE(memoryAddress, 0);
ASSERT_EQ(memoryAddress, oldMemoryAddress);
vect = reinterpret_cast<std::vector<float>*>(memoryAddress);

ASSERT_EQ(vect->size(), dim); //Since we just added 1 vector
ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim); //This is the initial capacity allocated

currentIndex = 0;
for(auto & i : data3) {
for(float j : i) {
ASSERT_FLOAT_EQ(vect->at(currentIndex), j);
currentIndex++;
}
}

// Check that freeing vector data works
knn_jni::commons::freeVectorData(memoryAddress);
}

TEST(StoreByteVectorTest, BasicAssertions) {
long dim = 3;
long totalNumberOfVector = 5;
std::vector<std::vector<uint8_t>> data;
for(int i = 0 ; i < totalNumberOfVector - 1 ; i++) {
std::vector<uint8_t> vector;
for(int j = 0 ; j < dim ; j ++) {
vector.push_back((uint8_t)j);
}
data.push_back(vector);
}
JNIEnv *jniEnv = nullptr;

testing::NiceMock<test_util::MockJNIUtil> mockJNIUtil;

jlong memoryAddress = knn_jni::commons::storeByteVectorData(&mockJNIUtil, jniEnv, (jlong)0,
reinterpret_cast<jobjectArray>(&data), (jlong)(totalNumberOfVector * dim), true);
ASSERT_NE(memoryAddress, 0);
auto *vect = reinterpret_cast<std::vector<uint8_t>*>(memoryAddress);
ASSERT_EQ(vect->size(), data.size() * dim);
ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim);

// Check by inserting more vectors at same memory location
jlong oldMemoryAddress = memoryAddress;
std::vector<std::vector<uint8_t>> data2;
std::vector<uint8_t> vector;
for(int j = 0 ; j < dim ; j ++) {
vector.push_back((uint8_t)j);
}
data2.push_back(vector);
memoryAddress = knn_jni::commons::storeByteVectorData(&mockJNIUtil, jniEnv, memoryAddress,
reinterpret_cast<jobjectArray>(&data2), (jlong)(totalNumberOfVector * dim), true);
ASSERT_NE(memoryAddress, 0);
ASSERT_EQ(memoryAddress, oldMemoryAddress);
vect = reinterpret_cast<std::vector<uint8_t>*>(memoryAddress);
int currentIndex = 0;
ASSERT_EQ(vect->size(), totalNumberOfVector*dim);
ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim);

// Validate if all vectors data are at correct location
for(auto & i : data) {
for(uint8_t j : i) {
ASSERT_EQ(vect->at(currentIndex), j);
currentIndex++;
}
}

for(auto & i : data2) {
for(uint8_t j : i) {
ASSERT_EQ(vect->at(currentIndex), j);
currentIndex++;
}
}

// test append == true
std::vector<std::vector<uint8_t>> data3;
std::vector<uint8_t> vecto3;
for(int j = 0 ; j < dim ; j ++) {
vecto3.push_back((uint8_t)j);
}
data3.push_back(vecto3);
memoryAddress = knn_jni::commons::storeByteVectorData(&mockJNIUtil, jniEnv, memoryAddress,
reinterpret_cast<jobjectArray>(&data3), (jlong)(totalNumberOfVector * dim), false);
ASSERT_NE(memoryAddress, 0);
ASSERT_EQ(memoryAddress, oldMemoryAddress);
vect = reinterpret_cast<std::vector<uint8_t>*>(memoryAddress);

ASSERT_EQ(vect->size(), dim);
ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim);

currentIndex = 0;
for(auto & i : data3) {
for(uint8_t j : i) {
ASSERT_EQ(vect->at(currentIndex), j);
currentIndex++;
}
}

// Check that freeing vector data works
knn_jni::commons::freeVectorData(memoryAddress);
}
Expand Down
40 changes: 40 additions & 0 deletions src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.common;

import lombok.experimental.UtilityClass;
import org.apache.lucene.index.FieldInfo;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.indices.ModelMetadata;

import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.indices.ModelUtil.getModelMetadata;

@UtilityClass
public class FieldInfoExtractor {

public static KNNEngine extractKNNEngine(final FieldInfo field) {
final ModelMetadata modelMetadata = getModelMetadata(field.attributes().get(MODEL_ID));
if (modelMetadata != null) {
return modelMetadata.getKnnEngine();
}
final String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName());
return KNNEngine.getEngine(engineName);
}

public static VectorDataType extractVectorDataType(final FieldInfo field) {
String vectorDataTypeString = field.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD);
if (vectorDataTypeString == null) {
final ModelMetadata modelMetadata = getModelMetadata(field.attributes().get(MODEL_ID));
if (modelMetadata != null) {
VectorDataType vectorDataType = modelMetadata.getVectorDataType();
vectorDataTypeString = vectorDataType == null ? null : vectorDataType.getValue();
}
}
return vectorDataTypeString != null ? VectorDataType.get(vectorDataTypeString) : VectorDataType.DEFAULT;
}
}
18 changes: 17 additions & 1 deletion src/main/java/org/opensearch/knn/common/KNNVectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

package org.opensearch.knn.common;

import java.util.Objects;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;

import java.util.ArrayList;
import java.util.Objects;

@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class KNNVectorUtil {
/**
Expand Down Expand Up @@ -42,4 +44,18 @@ public static boolean isZeroVector(float[] vector) {
}
return true;
}

/**
* Creates an int overflow safe arraylist. If there is an overflow it will create a list with default initial size
* @param batchSize size to allocate
* @return an arrayList
*/
public static <T> ArrayList<T> createArrayList(long batchSize) {
try {
return new ArrayList<>(Math.toIntExact(batchSize));
} catch (Exception exception) {
// No-op
}
return new ArrayList<>();
}
}
Loading

0 comments on commit af635b3

Please sign in to comment.