Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Add binary format support with IVF method in Faiss Engine #1863

Merged
merged 1 commit into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Add binary format support with HNSW method in Faiss Engine [#1781](https://github.com/opensearch-project/k-NN/pull/1781)
* Add script scoring support for knn field with binary data type [#1826](https://github.com/opensearch-project/k-NN/pull/1826)
* Add painless script support for hamming with binary vector data type [#1839](https://github.com/opensearch-project/k-NN/pull/1839)
* Add binary format support with IVF method in Faiss Engine [#1784](https://github.com/opensearch-project/k-NN/pull/1784)
### Enhancements
* Switch from byte stream to byte ref for serde [#1825](https://github.com/opensearch-project/k-NN/pull/1825)
### Bug Fixes
Expand Down
13 changes: 13 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ namespace knn_jni {
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
jobject parametersJ);

// Create an index with ids and vectors. Instead of creating a new index, this function creates the index
// based off of the template index passed in. The index is serialized to indexPathJ.
void CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
jobject parametersJ);

// Load an index from indexPathJ into memory.
//
// Return a pointer to the loaded index
Expand Down Expand Up @@ -96,6 +102,13 @@ namespace knn_jni {
jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
jlong trainVectorsPointerJ);

// Create an empty binary index defined by the values in the Java map, parametersJ. Train the index with
// the vector of floats located at trainVectorsPointerJ.
//
// Return the serialized representation
jbyteArray TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
jlong trainVectorsPointerJ);

/*
* Perform a range search with filter against the index located in memory at indexPointerJ.
*
Expand Down
16 changes: 16 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryInde
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createBinaryIndexFromTemplate
* Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: loadIndex
Expand Down Expand Up @@ -139,6 +147,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_initLibrary
JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex
(JNIEnv *, jclass, jobject, jint, jlong);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: trainBinaryIndex
* Signature: (Ljava/util/Map;IJ)[B
*/
JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinaryIndex
(JNIEnv *, jclass, jobject, jint, jlong);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: transferVectors
Expand Down
133 changes: 133 additions & 0 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env,
// Train an index with data provided
void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x);

// Train a binary index with data provided
void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x);

// Converts the int FilterIds to Faiss ids type array.
void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds);

Expand Down Expand Up @@ -223,6 +226,76 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface *
faiss::write_index(&idMap, indexPathCpp.c_str());
}

void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ,
jbyteArray templateIndexJ, jobject parametersJ) {
if (idsJ == nullptr) {
throw std::runtime_error("IDs cannot be null");
}

if (vectorsAddressJ <= 0) {
throw std::runtime_error("VectorsAddress cannot be less than 0");
}

if(dimJ <= 0) {
throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
}

if (indexPathJ == nullptr) {
throw std::runtime_error("Index path cannot be null");
}

if (templateIndexJ == nullptr) {
throw std::runtime_error("Template index cannot be null");
}

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);
if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
omp_set_num_threads(threadCount);
}
jniUtil->DeleteLocalRef(env, parametersJ);

// Read data set
// Read vectors from memory address
auto *inputVectors = reinterpret_cast<std::vector<uint8_t>*>(vectorsAddressJ);
int dim = (int)dimJ;
if (dim % 8 != 0) {
throw std::runtime_error("Dimensions should be multiply of 8");
}
int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8));
int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
if (numIds != numVectors) {
throw std::runtime_error("Number of IDs does not match number of vectors");
}

// Get vector of bytes from jbytearray
int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ);
jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr);

faiss::VectorIOReader vectorIoReader;
for (int i = 0; i < indexBytesCount; i++) {
vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]);
}
jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT);

// Create faiss index
std::unique_ptr<faiss::IndexBinary> indexWriter;
indexWriter.reset(faiss::read_index_binary(&vectorIoReader, 0));

auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
faiss::IndexBinaryIDMap idMap = faiss::IndexBinaryIDMap(indexWriter.get());
idMap.add_with_ids(numVectors, reinterpret_cast<const uint8_t*>(inputVectors->data()), idVector.data());
// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
// This is not the ideal approach, please refer this gh issue for long term solution:
// https://github.com/opensearch-project/k-NN/issues/1600
delete inputVectors;
// Write the index to disk
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
faiss::write_index_binary(&idMap, indexPathCpp.c_str());
}

jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) {
if (indexPathJ == nullptr) {
throw std::runtime_error("Index path cannot be null");
Expand Down Expand Up @@ -634,6 +707,57 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti
return ret;
}

jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ,
jint dimensionJ, jlong trainVectorsPointerJ) {
// First, we need to build the index
if (parametersJ == nullptr) {
throw std::runtime_error("Parameters cannot be null");
}

auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);

jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE);
std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ));
faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp);

// Create faiss index
jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION);
std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ));

std::unique_ptr<faiss::IndexBinary> indexWriter;
indexWriter.reset(faiss::index_binary_factory((int) dimensionJ, indexDescriptionCpp.c_str()));

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
omp_set_num_threads(threadCount);
}

// Train index if needed
auto *trainingVectorsPointerCpp = reinterpret_cast<std::vector<float>*>(trainVectorsPointerJ);
int numVectors = trainingVectorsPointerCpp->size()/(int) dimensionJ;
if(!indexWriter->is_trained) {
InternalTrainBinaryIndex(indexWriter.get(), numVectors, trainingVectorsPointerCpp->data());
}
jniUtil->DeleteLocalRef(env, parametersJ);

// Now that indexWriter is trained, we just load the bytes into an array and return
faiss::VectorIOWriter vectorIoWriter;
faiss::write_index_binary(indexWriter.get(), &vectorIoWriter);

// Wrap in smart pointer
std::unique_ptr<jbyte[]> jbytesBuffer;
jbytesBuffer.reset(new jbyte[vectorIoWriter.data.size()]);
int c = 0;
for (auto b : vectorIoWriter.data) {
jbytesBuffer[c++] = (jbyte) b;
}

jbyteArray ret = jniUtil->NewByteArray(env, vectorIoWriter.data.size());
jniUtil->SetByteArrayRegion(env, ret, 0, vectorIoWriter.data.size(), jbytesBuffer.get());
return ret;
}

faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType) {
if (spaceType == knn_jni::L2) {
return faiss::METRIC_L2;
Expand Down Expand Up @@ -692,6 +816,15 @@ void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) {
}
}

void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x) {
if (auto * indexIvf = dynamic_cast<faiss::IndexBinaryIVF*>(index)) {
indexIvf->make_direct_map();
}
if (!index->is_trained) {
index->train(n, reinterpret_cast<const uint8_t*>(x));
}
}

std::unique_ptr<faiss::IDGrouperBitmap> buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector<uint64_t>* bitmap) {
int *parentIdsArray = jniUtil->GetIntArrayElements(env, parentIdsJ, nullptr);
int parentIdsLength = jniUtil->GetJavaIntArrayLength(env, parentIdsJ);
Expand Down
28 changes: 28 additions & 0 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,21 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT
}
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate(JNIEnv * env, jclass cls,
jintArray idsJ,
jlong vectorsAddressJ,
jint dimJ,
jstring indexPathJ,
jbyteArray templateIndexJ,
jobject parametersJ)
{
try {
knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
}

JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex(JNIEnv * env, jclass cls, jstring indexPathJ)
{
try {
Expand Down Expand Up @@ -220,6 +235,19 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex
return nullptr;
}

JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinaryIndex(JNIEnv * env, jclass cls,
jobject parametersJ,
jint dimensionJ,
jlong trainVectorsPointerJ)
{
try {
return knn_jni::faiss_wrapper::TrainBinaryIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
return nullptr;
}

JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors(JNIEnv * env, jclass cls,
jlong vectorsPointerJ,
jobjectArray vectorsJ)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ public class KNNConstants {
public static final String SEARCH_SIZE_PARAMETER = "search_size";

public static final String VECTOR_DATA_TYPE_FIELD = "data_type";
public static final String MODEL_VECTOR_DATA_TYPE_KEY = VECTOR_DATA_TYPE_FIELD;
public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT;

public static final String RADIAL_SEARCH_KEY = "radial_search";
Expand Down
Loading
Loading