Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Introducing a loading layer in FAISS native engine. #2179

Merged
merged 1 commit into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 3.0](https://github.com/opensearch-project/k-NN/compare/2.x...HEAD)
### Features
### Enhancements
* Introducing a loading layer in FAISS [#2033](https://github.com/opensearch-project/k-NN/issues/2033)
### Bug Fixes
* Add DocValuesProducers for releasing memory when close index [#1946](https://github.com/opensearch-project/k-NN/pull/1946)
### Infrastructure
Expand Down
1 change: 1 addition & 0 deletions jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ if ("${WIN32}" STREQUAL "")
tests/nmslib_wrapper_unit_test.cpp
tests/test_util.cpp
tests/commons_test.cpp
tests/faiss_stream_support_test.cpp
tests/faiss_index_service_test.cpp
)

Expand Down
136 changes: 136 additions & 0 deletions jni/include/faiss_stream_support.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

#ifndef OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H
#define OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H

#include "faiss/impl/io.h"
#include "jni_util.h"

#include <jni.h>
#include <stdexcept>
#include <iostream>
#include <cstring>

namespace knn_jni {
namespace stream {

/**
* This class contains Java IndexInputWithBuffer reference and calls its API to copy required bytes into a read buffer.
*/

class NativeEngineIndexInputMediator {
public:
// Expect IndexInputWithBuffer is given as `_indexInput`.
NativeEngineIndexInputMediator(JNIUtilInterface *_jni_interface,
JNIEnv *_env,
jobject _indexInput)
: jni_interface(_jni_interface),
env(_env),
indexInput(_indexInput),
bufferArray((jbyteArray) (_jni_interface->GetObjectField(_env,
_indexInput,
getBufferFieldId(_jni_interface, _env)))),
copyBytesMethod(getCopyBytesMethod(_jni_interface, _env)) {
}

void copyBytes(int64_t nbytes, uint8_t *destination) {
while (nbytes > 0) {
// Call `copyBytes` to read bytes as many as possible.
const auto readBytes =
jni_interface->CallIntMethodLong(env, indexInput, copyBytesMethod, nbytes);

// === Critical Section Start ===

// Get primitive array pointer, no copy is happening in OpenJDK.
auto primitiveArray =
(jbyte *) jni_interface->GetPrimitiveArrayCritical(env, bufferArray, nullptr);

// Copy Java bytes to C++ destination address.
std::memcpy(destination, primitiveArray, readBytes);

// Release the acquired primitive array pointer.
// JNI_ABORT tells JVM to directly free memory without copying back to Java byte[].
// Since we're merely copying data, we don't need to copying back.
jni_interface->ReleasePrimitiveArrayCritical(env, bufferArray, primitiveArray, JNI_ABORT);

// === Critical Section End ===

destination += readBytes;
nbytes -= readBytes;
} // End while
}

private:
static jclass getIndexInputWithBufferClass(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jclass INDEX_INPUT_WITH_BUFFER_CLASS =
jni_interface->FindClassFromJNIEnv(env, "org/opensearch/knn/index/store/IndexInputWithBuffer");
return INDEX_INPUT_WITH_BUFFER_CLASS;
}

static jmethodID getCopyBytesMethod(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jmethodID COPY_METHOD_ID =
jni_interface->GetMethodID(env, getIndexInputWithBufferClass(jni_interface, env), "copyBytes", "(J)I");
return COPY_METHOD_ID;
}

static jfieldID getBufferFieldId(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jfieldID BUFFER_FIELD_ID =
jni_interface->GetFieldID(env, getIndexInputWithBufferClass(jni_interface, env), "buffer", "[B");
return BUFFER_FIELD_ID;
}

JNIUtilInterface *jni_interface;
JNIEnv *env;

// `IndexInputWithBuffer` instance having `IndexInput` instance obtained from `Directory` for reading.
jobject indexInput;
jbyteArray bufferArray;
jmethodID copyBytesMethod;
}; // class NativeEngineIndexInputMediator



/**
* A glue component inheriting IOReader to be passed down to Faiss library.
* This will then indirectly call the mediator component and eventually read required bytes from Lucene's IndexInput.
*/
class FaissOpenSearchIOReader final : public faiss::IOReader {
public:
explicit FaissOpenSearchIOReader(NativeEngineIndexInputMediator *_mediator)
: faiss::IOReader(),
mediator(_mediator) {
name = "FaissOpenSearchIOReader";
}

size_t operator()(void *ptr, size_t size, size_t nitems) final {
const auto readBytes = size * nitems;
if (readBytes > 0) {
// Mediator calls IndexInput, then copy read bytes to `ptr`.
mediator->copyBytes(readBytes, (uint8_t *) ptr);
}
return nitems;
}

int filedescriptor() final {
throw std::runtime_error("filedescriptor() is not supported in FaissOpenSearchIOReader.");
}

private:
NativeEngineIndexInputMediator *mediator;
}; // class FaissOpenSearchIOReader



}
}

#endif //OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H
10 changes: 10 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,21 @@ namespace knn_jni {
// Return a pointer to the loaded index
jlong LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ);

// Loads an index with a reader implemented IOReader
//
// Returns a pointer of the loaded index
jlong LoadIndexWithStream(faiss::IOReader* ioReader);

// Load a binary index from indexPathJ into memory.
//
// Return a pointer to the loaded index
jlong LoadBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ);

// Loads a binary index with a reader implemented IOReader
//
// Returns a pointer of the loaded index
jlong LoadBinaryIndexWithStream(faiss::IOReader* ioReader);

// Check if a loaded index requires shared state
bool IsSharedIndexStateRequired(jlong indexPointerJ);

Expand Down
102 changes: 61 additions & 41 deletions jni/include/jni_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
namespace knn_jni {

// Interface for making calls to JNI
class JNIUtilInterface {
public:
struct JNIUtilInterface {
// -------------------------- EXCEPTION HANDLING ----------------------------
// Takes the name of a Java exception type and a message and throws the corresponding exception
// to the JVM
Expand Down Expand Up @@ -127,56 +126,77 @@ namespace knn_jni {

virtual void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf) = 0;

virtual jobject GetObjectField(JNIEnv * env, jobject obj, jfieldID fieldID) = 0;

virtual jclass FindClassFromJNIEnv(JNIEnv * env, const char *name) = 0;

virtual jmethodID GetMethodID(JNIEnv * env, jclass clazz, const char *name, const char *sig) = 0;

virtual jfieldID GetFieldID(JNIEnv * env, jclass clazz, const char *name, const char *sig) = 0;

virtual void * GetPrimitiveArrayCritical(JNIEnv * env, jarray array, jboolean *isCopy) = 0;

virtual void ReleasePrimitiveArrayCritical(JNIEnv * env, jarray array, void *carray, jint mode) = 0;

virtual jint CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) = 0;

// --------------------------------------------------------------------------
};

jobject GetJObjectFromMapOrThrow(std::unordered_map<std::string, jobject> map, std::string key);

// Class that implements JNIUtilInterface methods
class JNIUtil: public JNIUtilInterface {
class JNIUtil final : public JNIUtilInterface {
public:
// Initialize and Uninitialize methods are used for caching/cleaning up Java classes and methods
void Initialize(JNIEnv* env);
void Uninitialize(JNIEnv* env);

void ThrowJavaException(JNIEnv* env, const char* type = "", const char* message = "");
void HasExceptionInStack(JNIEnv* env);
void HasExceptionInStack(JNIEnv* env, const std::string& message);
void CatchCppExceptionAndThrowJava(JNIEnv* env);
jclass FindClass(JNIEnv * env, const std::string& className);
jmethodID FindMethod(JNIEnv * env, const std::string& className, const std::string& methodName);
std::string ConvertJavaStringToCppString(JNIEnv * env, jstring javaString);
std::unordered_map<std::string, jobject> ConvertJavaMapToCppMap(JNIEnv *env, jobject parametersJ);
std::string ConvertJavaObjectToCppString(JNIEnv *env, jobject objectJ);
int ConvertJavaObjectToCppInteger(JNIEnv *env, jobject objectJ);
std::vector<float> Convert2dJavaObjectArrayToCppFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim);
std::vector<int64_t> ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ);
int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ);
int GetInnerDimensionOf2dJavaByteArray(JNIEnv *env, jobjectArray array2dJ);
int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ);
int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ);
int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ);
int GetJavaBytesArrayLength(JNIEnv *env, jbyteArray arrayJ);
int GetJavaFloatArrayLength(JNIEnv *env, jfloatArray arrayJ);

void DeleteLocalRef(JNIEnv *env, jobject obj);
jbyte * GetByteArrayElements(JNIEnv *env, jbyteArray array, jboolean * isCopy);
jfloat * GetFloatArrayElements(JNIEnv *env, jfloatArray array, jboolean * isCopy);
jint * GetIntArrayElements(JNIEnv *env, jintArray array, jboolean * isCopy);
jlong * GetLongArrayElements(JNIEnv *env, jlongArray array, jboolean * isCopy);
jobject GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index);
jobject NewObject(JNIEnv *env, jclass clazz, jmethodID methodId, int id, float distance);
jobjectArray NewObjectArray(JNIEnv *env, jsize len, jclass clazz, jobject init);
jbyteArray NewByteArray(JNIEnv *env, jsize len);
void ReleaseByteArrayElements(JNIEnv *env, jbyteArray array, jbyte *elems, int mode);
void ReleaseFloatArrayElements(JNIEnv *env, jfloatArray array, jfloat *elems, int mode);
void ReleaseIntArrayElements(JNIEnv *env, jintArray array, jint *elems, jint mode);
void ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode);
void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val);
void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf);
void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<float> *vect);
void Convert2dJavaObjectArrayAndStoreToBinaryVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<uint8_t> *vect);
void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<int8_t> *vect);
void ThrowJavaException(JNIEnv* env, const char* type = "", const char* message = "") final;
void HasExceptionInStack(JNIEnv* env) final;
void HasExceptionInStack(JNIEnv* env, const std::string& message) final;
void CatchCppExceptionAndThrowJava(JNIEnv* env) final;
jclass FindClass(JNIEnv * env, const std::string& className) final;
jmethodID FindMethod(JNIEnv * env, const std::string& className, const std::string& methodName) final;
std::string ConvertJavaStringToCppString(JNIEnv * env, jstring javaString) final;
std::unordered_map<std::string, jobject> ConvertJavaMapToCppMap(JNIEnv *env, jobject parametersJ) final;
std::string ConvertJavaObjectToCppString(JNIEnv *env, jobject objectJ) final;
int ConvertJavaObjectToCppInteger(JNIEnv *env, jobject objectJ) final;
std::vector<float> Convert2dJavaObjectArrayToCppFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim) final;
std::vector<int64_t> ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) final;
int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ) final;
int GetInnerDimensionOf2dJavaByteArray(JNIEnv *env, jobjectArray array2dJ) final;
int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ) final;
int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ) final;
int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ) final;
int GetJavaBytesArrayLength(JNIEnv *env, jbyteArray arrayJ) final;
int GetJavaFloatArrayLength(JNIEnv *env, jfloatArray arrayJ) final;

void DeleteLocalRef(JNIEnv *env, jobject obj) final;
jbyte * GetByteArrayElements(JNIEnv *env, jbyteArray array, jboolean * isCopy) final;
jfloat * GetFloatArrayElements(JNIEnv *env, jfloatArray array, jboolean * isCopy) final;
jint * GetIntArrayElements(JNIEnv *env, jintArray array, jboolean * isCopy) final;
jlong * GetLongArrayElements(JNIEnv *env, jlongArray array, jboolean * isCopy) final;
jobject GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index) final;
jobject NewObject(JNIEnv *env, jclass clazz, jmethodID methodId, int id, float distance) final;
jobjectArray NewObjectArray(JNIEnv *env, jsize len, jclass clazz, jobject init) final;
jbyteArray NewByteArray(JNIEnv *env, jsize len) final;
void ReleaseByteArrayElements(JNIEnv *env, jbyteArray array, jbyte *elems, int mode) final;
void ReleaseFloatArrayElements(JNIEnv *env, jfloatArray array, jfloat *elems, int mode) final;
void ReleaseIntArrayElements(JNIEnv *env, jintArray array, jint *elems, jint mode) final;
void ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode) final;
void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val) final;
void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf) final;
void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<float> *vect) final;
void Convert2dJavaObjectArrayAndStoreToBinaryVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<uint8_t> *vect) final;
void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<int8_t> *vect) final;
jobject GetObjectField(JNIEnv * env, jobject obj, jfieldID fieldID) final;
jclass FindClassFromJNIEnv(JNIEnv * env, const char *name) final;
jmethodID GetMethodID(JNIEnv * env, jclass clazz, const char *name, const char *sig) final;
jfieldID GetFieldID(JNIEnv * env, jclass clazz, const char *name, const char *sig) final;
jint CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) final;
void * GetPrimitiveArrayCritical(JNIEnv * env, jarray array, jboolean *isCopy) final;
void ReleasePrimitiveArrayCritical(JNIEnv * env, jarray array, void *carray, jint mode) final;

private:
std::unordered_map<std::string, jclass> cachedClasses;
Expand Down
16 changes: 16 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex
(JNIEnv *, jclass, jstring);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: loadIndexWithStream
* Signature: (Lorg/opensearch/knn/index/util/IndexInputWithBuffer;)J
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndexWithStream
(JNIEnv *, jclass, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: loadBinaryIndex
Expand All @@ -136,6 +144,14 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndex
(JNIEnv *, jclass, jstring);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: loadBinaryIndexWithStream
* Signature: (Lorg/opensearch/knn/index/util/IndexInputWithBuffer;)J
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndexWithStream
(JNIEnv *, jclass, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: isSharedIndexStateRequired
Expand Down
28 changes: 28 additions & 0 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,20 @@ jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNI
return (jlong) indexReader;
}

jlong knn_jni::faiss_wrapper::LoadIndexWithStream(faiss::IOReader* ioReader) {
if (ioReader == nullptr) [[unlikely]] {
throw std::runtime_error("IOReader cannot be null");
}

faiss::Index* indexReader =
faiss::read_index(ioReader,
faiss::IO_FLAG_READ_ONLY
| faiss::IO_FLAG_PQ_SKIP_SDC_TABLE
| faiss::IO_FLAG_SKIP_PRECOMPUTE_TABLE);

return (jlong) indexReader;
}

jlong knn_jni::faiss_wrapper::LoadBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) {
if (indexPathJ == nullptr) {
throw std::runtime_error("Index path cannot be null");
Expand All @@ -436,6 +450,20 @@ jlong knn_jni::faiss_wrapper::LoadBinaryIndex(knn_jni::JNIUtilInterface * jniUti
return (jlong) indexReader;
}

jlong knn_jni::faiss_wrapper::LoadBinaryIndexWithStream(faiss::IOReader* ioReader) {
if (ioReader == nullptr) [[unlikely]] {
throw std::runtime_error("IOReader cannot be null");
}

faiss::IndexBinary* indexReader =
faiss::read_index_binary(ioReader,
faiss::IO_FLAG_READ_ONLY
| faiss::IO_FLAG_PQ_SKIP_SDC_TABLE
| faiss::IO_FLAG_SKIP_PRECOMPUTE_TABLE);

return (jlong) indexReader;
}

bool knn_jni::faiss_wrapper::IsSharedIndexStateRequired(jlong indexPointerJ) {
auto * index = reinterpret_cast<faiss::Index*>(indexPointerJ);
return isIndexIVFPQL2(index);
Expand Down
Loading
Loading