Skip to content

Commit

Permalink
Add IVF changes to support Faiss byte vector (#2002)
Browse files Browse the repository at this point in the history
* Add HNSW changes to support Faiss byte vector

Signed-off-by: Naveen Tatikonda <[email protected]>

* Address Review Comments

Signed-off-by: Naveen Tatikonda <[email protected]>

* Add IVF changes to support Faiss byte vector

Signed-off-by: Naveen Tatikonda <[email protected]>

* Address Review Comments

Signed-off-by: Naveen Tatikonda <[email protected]>

---------

Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda authored Aug 30, 2024
1 parent 9dbe7de commit 2b303d9
Show file tree
Hide file tree
Showing 21 changed files with 626 additions and 92 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation [#1945](https://github.com/opensearch-project/k-NN/pull/1945)
* k-NN query rescore support for native engines [#1984](https://github.com/opensearch-project/k-NN/pull/1984)
* Add support for byte vector with Faiss Engine HNSW algorithm [#1823](https://github.com/opensearch-project/k-NN/pull/1823)
* Add support for byte vector with Faiss Engine IVF algorithm [#2002](https://github.com/opensearch-project/k-NN/pull/2002)
### Enhancements
* Adds iterative graph build capability into a faiss index to improve the memory footprint during indexing and Integrates KNNVectorsFormat for native engines[#1950](https://github.com/opensearch-project/k-NN/pull/1950)
### Bug Fixes
Expand Down
13 changes: 13 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ namespace knn_jni {
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
jobject parametersJ);

// Create a index with ids and byte vectors. Instead of creating a new index, this function creates the index
// based off of the template index passed in. The index is serialized to indexPathJ.
void CreateByteIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
jobject parametersJ);

// Load an index from indexPathJ into memory.
//
// Return a pointer to the loaded index
Expand Down Expand Up @@ -110,6 +116,13 @@ namespace knn_jni {
jbyteArray TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
jlong trainVectorsPointerJ);

// Create an empty byte index defined by the values in the Java map, parametersJ. Train the index with
// the byte vectors located at trainVectorsPointerJ.
//
// Return the serialized representation
jbyteArray TrainByteIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
jlong trainVectorsPointerJ);

/*
* Perform a range search with filter against the index located in memory at indexPointerJ.
*
Expand Down
16 changes: 16 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createByteIndexFromTemplate
* Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createByteIndexFromTemplate
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: loadIndex
Expand Down Expand Up @@ -216,6 +224,14 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex
JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinaryIndex
(JNIEnv *, jclass, jobject, jint, jlong);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: trainByteIndex
* Signature: (Ljava/util/Map;IJ)[B
*/
JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainByteIndex
(JNIEnv *, jclass, jobject, jint, jlong);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: transferVectors
Expand Down
157 changes: 157 additions & 0 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,96 @@ void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInter
faiss::write_index_binary(&idMap, indexPathCpp.c_str());
}

void knn_jni::faiss_wrapper::CreateByteIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ,
jbyteArray templateIndexJ, jobject parametersJ) {
if (idsJ == nullptr) {
throw std::runtime_error("IDs cannot be null");
}

if (vectorsAddressJ <= 0) {
throw std::runtime_error("VectorsAddress cannot be less than 0");
}

if(dimJ <= 0) {
throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
}

if (indexPathJ == nullptr) {
throw std::runtime_error("Index path cannot be null");
}

if (templateIndexJ == nullptr) {
throw std::runtime_error("Template index cannot be null");
}

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);
if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
omp_set_num_threads(threadCount);
}
jniUtil->DeleteLocalRef(env, parametersJ);

// Read data set
// Read vectors from memory address
auto *inputVectors = reinterpret_cast<std::vector<int8_t>*>(vectorsAddressJ);
int dim = (int)dimJ;
int numVectors = (int) (inputVectors->size() / (uint64_t) dim);
int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);

if (numIds != numVectors) {
throw std::runtime_error("Number of IDs does not match number of vectors");
}

// Get vector of bytes from jbytearray
int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ);
jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr);

faiss::VectorIOReader vectorIoReader;
for (int i = 0; i < indexBytesCount; i++) {
vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]);
}
jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT);

// Create faiss index
std::unique_ptr<faiss::Index> indexWriter;
indexWriter.reset(faiss::read_index(&vectorIoReader, 0));

auto ids = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get());

// Add vectors in batches by casting int8 vectors into float with a batch size of 1000 to avoid additional memory spike.
// Refer to this github issue for more details https://github.com/opensearch-project/k-NN/issues/1659#issuecomment-2307390255
int batchSize = 1000;
std::vector <float> inputFloatVectors(batchSize * dim);
std::vector <int64_t> floatVectorsIds(batchSize);
int id = 0;
auto iter = inputVectors->begin();

for (int id = 0; id < numVectors; id += batchSize) {
if (numVectors - id < batchSize) {
batchSize = numVectors - id;
}

for (int i = 0; i < batchSize; ++i) {
floatVectorsIds[i] = ids[id + i];
for (int j = 0; j < dim; ++j, ++iter) {
inputFloatVectors[i * dim + j] = static_cast<float>(*iter);
}
}
idMap.add_with_ids(batchSize, inputFloatVectors.data(), floatVectorsIds.data());
}

// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
// This is not the ideal approach, please refer this gh issue for long term solution:
// https://github.com/opensearch-project/k-NN/issues/1600
delete inputVectors;
// Write the index to disk
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
faiss::write_index(&idMap, indexPathCpp.c_str());
}

jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) {
if (indexPathJ == nullptr) {
throw std::runtime_error("Index path cannot be null");
Expand Down Expand Up @@ -782,6 +872,73 @@ jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex(knn_jni::JNIUtilInterface *
return ret;
}

jbyteArray knn_jni::faiss_wrapper::TrainByteIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ,
jint dimensionJ, jlong trainVectorsPointerJ) {
// First, we need to build the index
if (parametersJ == nullptr) {
throw std::runtime_error("Parameters cannot be null");
}

auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);

jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE);
std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ));
faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp);

// Create faiss index
jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION);
std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ));

std::unique_ptr<faiss::Index> indexWriter;
indexWriter.reset(faiss::index_factory((int) dimensionJ, indexDescriptionCpp.c_str(), metric));

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
omp_set_num_threads(threadCount);
}

// Add extra parameters that cant be configured with the index factory
if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) {
jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS];
auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ);
SetExtraParameters(jniUtil, env, subParametersCpp, indexWriter.get());
jniUtil->DeleteLocalRef(env, subParametersJ);
}

// Train index if needed
auto *trainingVectorsPointerCpp = reinterpret_cast<std::vector<int8_t>*>(trainVectorsPointerJ);
int numVectors = trainingVectorsPointerCpp->size()/(int) dimensionJ;

auto iter = trainingVectorsPointerCpp->begin();
std::vector <float> trainingFloatVectors(numVectors * dimensionJ);
for(int i=0; i < numVectors * dimensionJ; ++i, ++iter) {
trainingFloatVectors[i] = static_cast<float>(*iter);
}

if(!indexWriter->is_trained) {
InternalTrainIndex(indexWriter.get(), numVectors, trainingFloatVectors.data());
}
jniUtil->DeleteLocalRef(env, parametersJ);

// Now that indexWriter is trained, we just load the bytes into an array and return
faiss::VectorIOWriter vectorIoWriter;
faiss::write_index(indexWriter.get(), &vectorIoWriter);

// Wrap in smart pointer
std::unique_ptr<jbyte[]> jbytesBuffer;
jbytesBuffer.reset(new jbyte[vectorIoWriter.data.size()]);
int c = 0;
for (auto b : vectorIoWriter.data) {
jbytesBuffer[c++] = (jbyte) b;
}

jbyteArray ret = jniUtil->NewByteArray(env, vectorIoWriter.data.size());
jniUtil->SetByteArrayRegion(env, ret, 0, vectorIoWriter.data.size(), jbytesBuffer.get());
return ret;
}


faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType) {
if (spaceType == knn_jni::L2) {
return faiss::METRIC_L2;
Expand Down
28 changes: 28 additions & 0 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,21 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryInde
}
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createByteIndexFromTemplate(JNIEnv * env, jclass cls,
jintArray idsJ,
jlong vectorsAddressJ,
jint dimJ,
jstring indexPathJ,
jbyteArray templateIndexJ,
jobject parametersJ)
{
try {
knn_jni::faiss_wrapper::CreateByteIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
}

JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex(JNIEnv * env, jclass cls, jstring indexPathJ)
{
try {
Expand Down Expand Up @@ -335,6 +350,19 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinar
return nullptr;
}

JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainByteIndex(JNIEnv * env, jclass cls,
jobject parametersJ,
jint dimensionJ,
jlong trainVectorsPointerJ)
{
try {
return knn_jni::faiss_wrapper::TrainByteIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
return nullptr;
}

JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors(JNIEnv * env, jclass cls,
jlong vectorsPointerJ,
jobjectArray vectorsJ)
Expand Down
81 changes: 81 additions & 0 deletions jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,55 @@ TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) {
std::remove(indexPath.c_str());
}

TEST(FaissCreateByteIndexFromTemplateTest, BasicAssertions) {
// Define the data
faiss::idx_t numIds = 100;
std::vector<faiss::idx_t> ids;
auto *vectors = new std::vector<int8_t>();
int dim = 8;
vectors->reserve(dim * numIds);
for (int64_t i = 0; i < numIds; ++i) {
ids.push_back(i);
for (int j = 0; j < dim; ++j) {
vectors->push_back(test_util::RandomInt(-128, 127));
}
}

std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss");
faiss::MetricType metricType = faiss::METRIC_L2;
std::string method = "HNSW32,SQ8_direct_signed";

std::unique_ptr<faiss::Index> createdIndex(
test_util::FaissCreateIndex(dim, method, metricType));
auto vectorIoWriter = test_util::FaissGetSerializedIndex(createdIndex.get());

// Setup jni
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;

EXPECT_CALL(mockJNIUtil,
GetJavaObjectArrayLength(
jniEnv, reinterpret_cast<jobjectArray>(&vectors)))
.WillRepeatedly(Return(vectors->size()));

std::string spaceType = knn_jni::L2;
std::unordered_map<std::string, jobject> parametersMap;
parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType;

knn_jni::faiss_wrapper::CreateByteIndexFromTemplate(
&mockJNIUtil, jniEnv, reinterpret_cast<jintArray>(&ids),
(jlong)vectors, dim, (jstring)&indexPath,
reinterpret_cast<jbyteArray>(&(vectorIoWriter.data)),
(jobject) &parametersMap
);

// Make sure index can be loaded
std::unique_ptr<faiss::Index> index(test_util::FaissLoadIndex(indexPath));

// Clean up
std::remove(indexPath.c_str());
}

TEST(FaissLoadIndexTest, BasicAssertions) {
// Define the data
faiss::idx_t numIds = 100;
Expand Down Expand Up @@ -717,6 +766,38 @@ TEST(FaissTrainIndexTest, BasicAssertions) {
ASSERT_TRUE(trainedIndex->is_trained);
}

TEST(FaissTrainByteIndexTest, BasicAssertions) {
// Define the index configuration
int dim = 2;
std::string spaceType = knn_jni::L2;
std::string index_description = "IVF4,SQ8_direct_signed";

std::unordered_map<std::string, jobject> parametersMap;
parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType;
parametersMap[knn_jni::INDEX_DESCRIPTION] = (jobject) &index_description;

// Define training data
int numTrainingVectors = 256;
std::vector<int8_t> trainingVectors = test_util::RandomByteVectors(dim, numTrainingVectors, -128, 127);

// Setup jni
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;

// Perform training
std::unique_ptr<std::vector<uint8_t>> trainedIndexSerialization(
reinterpret_cast<std::vector<uint8_t> *>(
knn_jni::faiss_wrapper::TrainByteIndex(
&mockJNIUtil, jniEnv, (jobject) &parametersMap, dim,
reinterpret_cast<jlong>(&trainingVectors))));

std::unique_ptr<faiss::Index> trainedIndex(
test_util::FaissLoadFromSerializedIndex(trainedIndexSerialization.get()));

// Confirm that training succeeded
ASSERT_TRUE(trainedIndex->is_trained);
}

TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) {
// Define the data
faiss::idx_t numIds = 200;
Expand Down
8 changes: 8 additions & 0 deletions jni/tests/test_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,14 @@ std::vector<float> test_util::RandomVectors(int dim, int64_t numVectors, float m
return vectors;
}

std::vector<int8_t> test_util::RandomByteVectors(int dim, int64_t numVectors, int min, int max) {
std::vector<int8_t> vectors(dim*numVectors);
for (int64_t i = 0; i < dim*numVectors; i++) {
vectors[i] = test_util::RandomInt(min, max);
}
return vectors;
}

std::vector<int64_t> test_util::Range(int64_t numElements) {
std::vector<int64_t> rangeVector(numElements);
for (int64_t i = 0; i < numElements; i++) {
Expand Down
2 changes: 2 additions & 0 deletions jni/tests/test_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ namespace test_util {

std::vector<float> RandomVectors(int dim, int64_t numVectors, float min, float max);

std::vector<int8_t> RandomByteVectors(int dim, int64_t numVectors, int min, int max);

std::vector<int64_t> Range(int64_t numElements);

// returns the number of 64 bit words it would take to hold numBits
Expand Down
Loading

0 comments on commit 2b303d9

Please sign in to comment.