Skip to content

Commit

Permalink
Add binary format support for IVF train model and create index api
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Jul 8, 2024
1 parent ab0a2d1 commit 43f1bf0
Show file tree
Hide file tree
Showing 14 changed files with 307 additions and 26 deletions.
24 changes: 24 additions & 0 deletions jni/include/faiss_index_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
#include "faiss_methods.h"
#include <memory>

namespace faiss {
struct VectorIOReader;
}

namespace knn_jni {
namespace faiss_wrapper {

Expand Down Expand Up @@ -61,6 +65,16 @@ class IndexService {
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters);

virtual void createIndexFromTemplate(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
faiss::VectorIOReader vectorIoReader,
std::vector<int64_t> idVector,
int numVectors,
std::vector<float> *inputVectors,
std::string& indexPathCpp);

virtual ~IndexService() = default;
protected:
std::unique_ptr<FaissMethods> faissMethods;
Expand Down Expand Up @@ -103,6 +117,16 @@ class BinaryIndexService : public IndexService {
std::string indexPath,
std::unordered_map<std::string, jobject> parameters
) override;

virtual void createIndexFromTemplate(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
faiss::VectorIOReader vectorIoReader,
std::vector<int64_t> idVector,
int numVectors,
std::vector<float> *inputVectors,
std::string& indexPathCpp) override;

virtual ~BinaryIndexService() = default;
};

Expand Down
2 changes: 2 additions & 0 deletions jni/include/faiss_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class FaissMethods {
virtual faiss::IndexIDMapTemplate<faiss::IndexBinary>* indexBinaryIdMap(faiss::IndexBinary* index);
virtual void writeIndex(const faiss::Index* idx, const char* fname);
virtual void writeIndexBinary(const faiss::IndexBinary* idx, const char* fname);
virtual faiss::Index* readIndex(const char* indexPath);
virtual faiss::IndexBinary* readIndexBinary(const char* indexPath);
virtual ~FaissMethods() = default;
};

Expand Down
9 changes: 8 additions & 1 deletion jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace knn_jni {
// based off of the template index passed in. The index is serialized to indexPathJ.
void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
jobject parametersJ);
jobject parametersJ, IndexService* indexService);

// Load an index from indexPathJ into memory.
//
Expand Down Expand Up @@ -102,6 +102,13 @@ namespace knn_jni {
jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
jlong trainVectorsPointerJ);

// Create an empty binary index defined by the values in the Java map, parametersJ. Train the index with
// the vector of floats located at trainVectorsPointerJ.
//
// Return the serialized representation
jbyteArray TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
jlong trainVectorsPointerJ);

/*
* Perform a range search with filter against the index located in memory at indexPointerJ.
*
Expand Down
17 changes: 17 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndex
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject);


/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createIndexFromTemplate
Expand All @@ -43,6 +44,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryInde
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createBinaryIndexFromTemplate
* Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: loadIndex
Expand Down Expand Up @@ -147,6 +156,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_initLibrary
JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex
(JNIEnv *, jclass, jobject, jint, jlong);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: trainBinaryIndex
* Signature: (Ljava/util/Map;IJ)[B
*/
JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinaryIndex
(JNIEnv *, jclass, jobject, jint, jlong);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: transferVectors
Expand Down
58 changes: 57 additions & 1 deletion jni/src/faiss_index_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <vector>
#include <memory>
#include <type_traits>
#include <faiss/impl/io.h>

namespace knn_jni {
namespace faiss_wrapper {
Expand Down Expand Up @@ -106,6 +107,31 @@ void IndexService::createIndex(
faissMethods->writeIndex(idMap.get(), indexPath.c_str());
}

void IndexService::createIndexFromTemplate(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
faiss::VectorIOReader vectorIoReader,
std::vector<int64_t> idVector,
int numVectors,
std::vector<float> *inputVectors,
std::string& indexPathCpp) {
// Read vectors from memory address
// Create faiss index
std::unique_ptr<faiss::Index> indexWriter;
indexWriter.reset(faiss::read_index(&vectorIoReader, 0));

faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get());

idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data());

// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
// This is not the ideal approach, please refer this gh issue for long term solution:
// https://github.com/opensearch-project/k-NN/issues/1600
delete inputVectors;

faiss::write_index(&idMap, indexPathCpp.c_str());
}

BinaryIndexService::BinaryIndexService(std::unique_ptr<FaissMethods> faissMethods) : IndexService(std::move(faissMethods)) {}

void BinaryIndexService::createIndex(
Expand Down Expand Up @@ -160,5 +186,35 @@ void BinaryIndexService::createIndex(
faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str());
}

void BinaryIndexService::createIndexFromTemplate(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
faiss::VectorIOReader vectorIoReader,
std::vector<int64_t> idVector,
int numVectors,
std::vector<float> *inputVectors,
std::string& indexPathCpp) {
// Read vectors from memory address
// Create faiss index
std::unique_ptr<faiss::IndexBinary> indexWriter;
indexWriter.reset(dynamic_cast<faiss::IndexBinary*>(faiss::read_index(&vectorIoReader, 0)));

// faiss::IndexBinaryIDMap idMap = faiss::IndexBinaryIDMap(indexWriter.get());

// idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data());
std::unique_ptr<faiss::IndexBinaryIDMap> idMap(faissMethods->indexBinaryIdMap(indexWriter.get()));

auto* vectorData = reinterpret_cast<uint8_t*>(inputVectors->data());
idMap->add_with_ids(numVectors, vectorData, idVector.data());

// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
// This is not the ideal approach, please refer this gh issue for long term solution:
// https://github.com/opensearch-project/k-NN/issues/1600
delete inputVectors;

// Write the index to disk
faissMethods->writeIndexBinary(idMap.get(), indexPathCpp.c_str());
}

} // namespace faiss_wrapper
} // namesapce knn_jni
} // namesapce knn_jni
9 changes: 9 additions & 0 deletions jni/src/faiss_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,18 @@ faiss::IndexIDMapTemplate<faiss::IndexBinary>* FaissMethods::indexBinaryIdMap(fa
void FaissMethods::writeIndex(const faiss::Index* idx, const char* fname) {
faiss::write_index(idx, fname);
}

void FaissMethods::writeIndexBinary(const faiss::IndexBinary* idx, const char* fname) {
faiss::write_index_binary(idx, fname);
}

faiss::Index* FaissMethods::readIndex(const char* indexPath) {
return faiss::read_index(indexPath);
}

faiss::IndexBinary* FaissMethods::readIndexBinary(const char* indexPath) {
return reinterpret_cast<faiss::IndexBinary*>(faiss::read_index_binary(indexPath));
}

} // namespace faiss_wrapper
} // namesapce knn_jni
97 changes: 86 additions & 11 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env,
// Train an index with data provided
void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x);

// Train a binary index with data provided
void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x);

// Converts the int FilterIds to Faiss ids type array.
void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds);

Expand Down Expand Up @@ -152,13 +155,15 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN
}
// end parameters to pass

std::cout << "Index description in CreateIndex: " << indexDescriptionCpp << std::endl;

// Create index
indexService->createIndex(jniUtil, env, metric, indexDescriptionCpp, dim, numIds, threadCount, vectorsAddress, ids, indexPathCpp, subParametersCpp);
}

void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ,
jbyteArray templateIndexJ, jobject parametersJ) {
jbyteArray templateIndexJ, jobject parametersJ, IndexService* indexService) {
if (idsJ == nullptr) {
throw std::runtime_error("IDs cannot be null");
}
Expand Down Expand Up @@ -187,7 +192,6 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface *
}
jniUtil->DeleteLocalRef(env, parametersJ);

// Read data set
// Read vectors from memory address
auto *inputVectors = reinterpret_cast<std::vector<float>*>(vectorsAddressJ);
int dim = (int)dimJ;
Expand All @@ -207,20 +211,28 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface *
}
jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT);

// Create faiss index
std::unique_ptr<faiss::Index> indexWriter;
indexWriter.reset(faiss::read_index(&vectorIoReader, 0));

auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get());
idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data());


std::vector<uint8_t> indexBytes(indexBytesJ, indexBytesJ + indexBytesCount);
jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT);

// Convert IDs to vector
auto ids = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);

// Index path
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));

// Template index path
std::string templateIndexPathCpp(reinterpret_cast<char*>(indexBytes.data()), indexBytes.size());

// Create index from template
indexService->createIndexFromTemplate(jniUtil, env, vectorIoReader, idVector, numVectors, inputVectors, indexPathCpp);

// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
// This is not the ideal approach, please refer this gh issue for long term solution:
// https://github.com/opensearch-project/k-NN/issues/1600
delete inputVectors;
// Write the index to disk
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
faiss::write_index(&idMap, indexPathCpp.c_str());
}

jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) {
Expand Down Expand Up @@ -568,6 +580,7 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti
jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION);
std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ));

std::cout << "Index description in TrainIndex: " << indexDescriptionCpp << std::endl;
std::unique_ptr<faiss::Index> indexWriter;
indexWriter.reset(faiss::index_factory((int) dimensionJ, indexDescriptionCpp.c_str(), metric));

Expand Down Expand Up @@ -617,6 +630,58 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti
return ret;
}

jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ,
jint dimensionJ, jlong trainVectorsPointerJ) {
// First, we need to build the index
if (parametersJ == nullptr) {
throw std::runtime_error("Parameters cannot be null");
}

auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);

jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE);
std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ));
faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp);

// Create faiss index
jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION);
std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ));

std::cout << "Index description in TrainIndex: " << indexDescriptionCpp << std::endl;
std::unique_ptr<faiss::IndexBinary> indexWriter;
indexWriter.reset(faiss::index_binary_factory((int) dimensionJ, indexDescriptionCpp.c_str()));

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
omp_set_num_threads(threadCount);
}

// Train index if needed
auto *trainingVectorsPointerCpp = reinterpret_cast<std::vector<float>*>(trainVectorsPointerJ);
int numVectors = trainingVectorsPointerCpp->size()/(int) dimensionJ;
if(!indexWriter->is_trained) {
InternalTrainBinaryIndex(indexWriter.get(), numVectors, trainingVectorsPointerCpp->data());
}
jniUtil->DeleteLocalRef(env, parametersJ);

// Now that indexWriter is trained, we just load the bytes into an array and return
faiss::VectorIOWriter vectorIoWriter;
faiss::write_index_binary(indexWriter.get(), &vectorIoWriter);

// Wrap in smart pointer
std::unique_ptr<jbyte[]> jbytesBuffer;
jbytesBuffer.reset(new jbyte[vectorIoWriter.data.size()]);
int c = 0;
for (auto b : vectorIoWriter.data) {
jbytesBuffer[c++] = (jbyte) b;
}

jbyteArray ret = jniUtil->NewByteArray(env, vectorIoWriter.data.size());
jniUtil->SetByteArrayRegion(env, ret, 0, vectorIoWriter.data.size(), jbytesBuffer.get());
return ret;
}

faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType) {
if (spaceType == knn_jni::L2) {
return faiss::METRIC_L2;
Expand Down Expand Up @@ -675,6 +740,16 @@ void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) {
}
}

void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x) {
if (auto * indexIvf = dynamic_cast<faiss::IndexBinaryIVF*>(index)) {
std::cout << "Index is IVFBinary" << std::endl;
indexIvf->make_direct_map();
}
if (!index->is_trained) {
index->train(n, reinterpret_cast<const uint8_t*>(x));
}
}

std::unique_ptr<faiss::IDGrouperBitmap> buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector<uint64_t>* bitmap) {
int *parentIdsArray = jniUtil->GetIntArrayElements(env, parentIdsJ, nullptr);
int parentIdsLength = jniUtil->GetJavaIntArrayLength(env, parentIdsJ);
Expand Down
Loading

0 comments on commit 43f1bf0

Please sign in to comment.