Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Add IVF changes to support Faiss byte vector #2017

Merged
merged 1 commit into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation [#1945](https://github.com/opensearch-project/k-NN/pull/1945)
* k-NN query rescore support for native engines [#1984](https://github.com/opensearch-project/k-NN/pull/1984)
* Add support for byte vector with Faiss Engine HNSW algorithm [#1823](https://github.com/opensearch-project/k-NN/pull/1823)
* Add support for byte vector with Faiss Engine IVF algorithm [#2002](https://github.com/opensearch-project/k-NN/pull/2002)
### Enhancements
* Adds iterative graph build capability into a faiss index to improve the memory footprint during indexing and Integrates KNNVectorsFormat for native engines[#1950](https://github.com/opensearch-project/k-NN/pull/1950)
### Bug Fixes
Expand Down
13 changes: 13 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ namespace knn_jni {
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
jobject parametersJ);

// Create a index with ids and byte vectors. Instead of creating a new index, this function creates the index
// based off of the template index passed in. The index is serialized to indexPathJ.
void CreateByteIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
jobject parametersJ);

// Load an index from indexPathJ into memory.
//
// Return a pointer to the loaded index
Expand Down Expand Up @@ -110,6 +116,13 @@ namespace knn_jni {
jbyteArray TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
jlong trainVectorsPointerJ);

// Create an empty byte index defined by the values in the Java map, parametersJ. Train the index with
// the byte vectors located at trainVectorsPointerJ.
//
// Return the serialized representation
jbyteArray TrainByteIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
jlong trainVectorsPointerJ);

/*
* Perform a range search with filter against the index located in memory at indexPointerJ.
*
Expand Down
16 changes: 16 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createByteIndexFromTemplate
* Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createByteIndexFromTemplate
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: loadIndex
Expand Down Expand Up @@ -216,6 +224,14 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex
JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinaryIndex
(JNIEnv *, jclass, jobject, jint, jlong);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: trainByteIndex
* Signature: (Ljava/util/Map;IJ)[B
*/
JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainByteIndex
(JNIEnv *, jclass, jobject, jint, jlong);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: transferVectors
Expand Down
157 changes: 157 additions & 0 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,96 @@ void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInter
faiss::write_index_binary(&idMap, indexPathCpp.c_str());
}

void knn_jni::faiss_wrapper::CreateByteIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ,
jbyteArray templateIndexJ, jobject parametersJ) {
if (idsJ == nullptr) {
throw std::runtime_error("IDs cannot be null");
}

if (vectorsAddressJ <= 0) {
throw std::runtime_error("VectorsAddress cannot be less than 0");
}

if(dimJ <= 0) {
throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
}

if (indexPathJ == nullptr) {
throw std::runtime_error("Index path cannot be null");
}

if (templateIndexJ == nullptr) {
throw std::runtime_error("Template index cannot be null");
}

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);
if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
omp_set_num_threads(threadCount);
}
jniUtil->DeleteLocalRef(env, parametersJ);

// Read data set
// Read vectors from memory address
auto *inputVectors = reinterpret_cast<std::vector<int8_t>*>(vectorsAddressJ);
int dim = (int)dimJ;
int numVectors = (int) (inputVectors->size() / (uint64_t) dim);
int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);

if (numIds != numVectors) {
throw std::runtime_error("Number of IDs does not match number of vectors");
}

// Get vector of bytes from jbytearray
int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ);
jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr);

faiss::VectorIOReader vectorIoReader;
for (int i = 0; i < indexBytesCount; i++) {
vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]);
}
jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT);

// Create faiss index
std::unique_ptr<faiss::Index> indexWriter;
indexWriter.reset(faiss::read_index(&vectorIoReader, 0));

auto ids = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get());

// Add vectors in batches by casting int8 vectors into float with a batch size of 1000 to avoid additional memory spike.
// Refer to this github issue for more details https://github.com/opensearch-project/k-NN/issues/1659#issuecomment-2307390255
int batchSize = 1000;
std::vector <float> inputFloatVectors(batchSize * dim);
std::vector <int64_t> floatVectorsIds(batchSize);
int id = 0;
auto iter = inputVectors->begin();

for (int id = 0; id < numVectors; id += batchSize) {
if (numVectors - id < batchSize) {
batchSize = numVectors - id;
}

for (int i = 0; i < batchSize; ++i) {
floatVectorsIds[i] = ids[id + i];
for (int j = 0; j < dim; ++j, ++iter) {
inputFloatVectors[i * dim + j] = static_cast<float>(*iter);
}
}
idMap.add_with_ids(batchSize, inputFloatVectors.data(), floatVectorsIds.data());
}

// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
// This is not the ideal approach, please refer this gh issue for long term solution:
// https://github.com/opensearch-project/k-NN/issues/1600
delete inputVectors;
// Write the index to disk
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
faiss::write_index(&idMap, indexPathCpp.c_str());
}

jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) {
if (indexPathJ == nullptr) {
throw std::runtime_error("Index path cannot be null");
Expand Down Expand Up @@ -782,6 +872,73 @@ jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex(knn_jni::JNIUtilInterface *
return ret;
}

jbyteArray knn_jni::faiss_wrapper::TrainByteIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ,
jint dimensionJ, jlong trainVectorsPointerJ) {
// First, we need to build the index
if (parametersJ == nullptr) {
throw std::runtime_error("Parameters cannot be null");
}

auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);

jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE);
std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ));
faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp);

// Create faiss index
jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION);
std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ));

std::unique_ptr<faiss::Index> indexWriter;
indexWriter.reset(faiss::index_factory((int) dimensionJ, indexDescriptionCpp.c_str(), metric));

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
omp_set_num_threads(threadCount);
}

// Add extra parameters that cant be configured with the index factory
if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) {
jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS];
auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ);
SetExtraParameters(jniUtil, env, subParametersCpp, indexWriter.get());
jniUtil->DeleteLocalRef(env, subParametersJ);
}

// Train index if needed
auto *trainingVectorsPointerCpp = reinterpret_cast<std::vector<int8_t>*>(trainVectorsPointerJ);
int numVectors = trainingVectorsPointerCpp->size()/(int) dimensionJ;

auto iter = trainingVectorsPointerCpp->begin();
std::vector <float> trainingFloatVectors(numVectors * dimensionJ);
for(int i=0; i < numVectors * dimensionJ; ++i, ++iter) {
trainingFloatVectors[i] = static_cast<float>(*iter);
}

if(!indexWriter->is_trained) {
InternalTrainIndex(indexWriter.get(), numVectors, trainingFloatVectors.data());
}
jniUtil->DeleteLocalRef(env, parametersJ);

// Now that indexWriter is trained, we just load the bytes into an array and return
faiss::VectorIOWriter vectorIoWriter;
faiss::write_index(indexWriter.get(), &vectorIoWriter);

// Wrap in smart pointer
std::unique_ptr<jbyte[]> jbytesBuffer;
jbytesBuffer.reset(new jbyte[vectorIoWriter.data.size()]);
int c = 0;
for (auto b : vectorIoWriter.data) {
jbytesBuffer[c++] = (jbyte) b;
}

jbyteArray ret = jniUtil->NewByteArray(env, vectorIoWriter.data.size());
jniUtil->SetByteArrayRegion(env, ret, 0, vectorIoWriter.data.size(), jbytesBuffer.get());
return ret;
}


faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType) {
if (spaceType == knn_jni::L2) {
return faiss::METRIC_L2;
Expand Down
28 changes: 28 additions & 0 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,21 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryInde
}
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createByteIndexFromTemplate(JNIEnv * env, jclass cls,
jintArray idsJ,
jlong vectorsAddressJ,
jint dimJ,
jstring indexPathJ,
jbyteArray templateIndexJ,
jobject parametersJ)
{
try {
knn_jni::faiss_wrapper::CreateByteIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
}

JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex(JNIEnv * env, jclass cls, jstring indexPathJ)
{
try {
Expand Down Expand Up @@ -335,6 +350,19 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinar
return nullptr;
}

JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainByteIndex(JNIEnv * env, jclass cls,
jobject parametersJ,
jint dimensionJ,
jlong trainVectorsPointerJ)
{
try {
return knn_jni::faiss_wrapper::TrainByteIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
return nullptr;
}

JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors(JNIEnv * env, jclass cls,
jlong vectorsPointerJ,
jobjectArray vectorsJ)
Expand Down
81 changes: 81 additions & 0 deletions jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,55 @@ TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) {
std::remove(indexPath.c_str());
}

TEST(FaissCreateByteIndexFromTemplateTest, BasicAssertions) {
// Define the data
faiss::idx_t numIds = 100;
std::vector<faiss::idx_t> ids;
auto *vectors = new std::vector<int8_t>();
int dim = 8;
vectors->reserve(dim * numIds);
for (int64_t i = 0; i < numIds; ++i) {
ids.push_back(i);
for (int j = 0; j < dim; ++j) {
vectors->push_back(test_util::RandomInt(-128, 127));
}
}

std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss");
faiss::MetricType metricType = faiss::METRIC_L2;
std::string method = "HNSW32,SQ8_direct_signed";

std::unique_ptr<faiss::Index> createdIndex(
test_util::FaissCreateIndex(dim, method, metricType));
auto vectorIoWriter = test_util::FaissGetSerializedIndex(createdIndex.get());

// Setup jni
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;

EXPECT_CALL(mockJNIUtil,
GetJavaObjectArrayLength(
jniEnv, reinterpret_cast<jobjectArray>(&vectors)))
.WillRepeatedly(Return(vectors->size()));

std::string spaceType = knn_jni::L2;
std::unordered_map<std::string, jobject> parametersMap;
parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType;

knn_jni::faiss_wrapper::CreateByteIndexFromTemplate(
&mockJNIUtil, jniEnv, reinterpret_cast<jintArray>(&ids),
(jlong)vectors, dim, (jstring)&indexPath,
reinterpret_cast<jbyteArray>(&(vectorIoWriter.data)),
(jobject) &parametersMap
);

// Make sure index can be loaded
std::unique_ptr<faiss::Index> index(test_util::FaissLoadIndex(indexPath));

// Clean up
std::remove(indexPath.c_str());
}

TEST(FaissLoadIndexTest, BasicAssertions) {
// Define the data
faiss::idx_t numIds = 100;
Expand Down Expand Up @@ -717,6 +766,38 @@ TEST(FaissTrainIndexTest, BasicAssertions) {
ASSERT_TRUE(trainedIndex->is_trained);
}

TEST(FaissTrainByteIndexTest, BasicAssertions) {
// Define the index configuration
int dim = 2;
std::string spaceType = knn_jni::L2;
std::string index_description = "IVF4,SQ8_direct_signed";

std::unordered_map<std::string, jobject> parametersMap;
parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType;
parametersMap[knn_jni::INDEX_DESCRIPTION] = (jobject) &index_description;

// Define training data
int numTrainingVectors = 256;
std::vector<int8_t> trainingVectors = test_util::RandomByteVectors(dim, numTrainingVectors, -128, 127);

// Setup jni
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;

// Perform training
std::unique_ptr<std::vector<uint8_t>> trainedIndexSerialization(
reinterpret_cast<std::vector<uint8_t> *>(
knn_jni::faiss_wrapper::TrainByteIndex(
&mockJNIUtil, jniEnv, (jobject) &parametersMap, dim,
reinterpret_cast<jlong>(&trainingVectors))));

std::unique_ptr<faiss::Index> trainedIndex(
test_util::FaissLoadFromSerializedIndex(trainedIndexSerialization.get()));

// Confirm that training succeeded
ASSERT_TRUE(trainedIndex->is_trained);
}

TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) {
// Define the data
faiss::idx_t numIds = 200;
Expand Down
8 changes: 8 additions & 0 deletions jni/tests/test_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,14 @@ std::vector<float> test_util::RandomVectors(int dim, int64_t numVectors, float m
return vectors;
}

std::vector<int8_t> test_util::RandomByteVectors(int dim, int64_t numVectors, int min, int max) {
std::vector<int8_t> vectors(dim*numVectors);
for (int64_t i = 0; i < dim*numVectors; i++) {
vectors[i] = test_util::RandomInt(min, max);
}
return vectors;
}

std::vector<int64_t> test_util::Range(int64_t numElements) {
std::vector<int64_t> rangeVector(numElements);
for (int64_t i = 0; i < numElements; i++) {
Expand Down
2 changes: 2 additions & 0 deletions jni/tests/test_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ namespace test_util {

std::vector<float> RandomVectors(int dim, int64_t numVectors, float min, float max);

std::vector<int8_t> RandomByteVectors(int dim, int64_t numVectors, int min, int max);

std::vector<int64_t> Range(int64_t numElements);

// returns the number of 64 bit words it would take to hold numBits
Expand Down
Loading
Loading