Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add jni interface to use a binary hnsw index with faiss #1778

Merged
merged 2 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ set(TARGET_LIBS "") # Libs to be installed

set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED True)

option(CONFIG_FAISS "Configure faiss library build when this is on")
option(CONFIG_NMSLIB "Configure nmslib library build when this is on")
option(CONFIG_TEST "Configure tests when this is on")
Expand Down Expand Up @@ -112,6 +111,8 @@ if (${CONFIG_FAISS} STREQUAL ON OR ${CONFIG_ALL} STREQUAL ON OR ${CONFIG_TEST} S
${CMAKE_CURRENT_SOURCE_DIR}/src/org_opensearch_knn_jni_FaissService.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_wrapper.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_util.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_index_service.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_methods.cpp
)
target_link_libraries(${TARGET_LIB_FAISS} ${TARGET_LINK_FAISS_LIB} ${TARGET_LIB_UTIL} OpenMP::OpenMP_CXX)
target_include_directories(${TARGET_LIB_FAISS} PRIVATE
Expand Down Expand Up @@ -153,6 +154,7 @@ if ("${WIN32}" STREQUAL "")
tests/nmslib_wrapper_unit_test.cpp
tests/test_util.cpp
tests/commons_test.cpp
tests/faiss_index_service_test.cpp
)

target_link_libraries(
Expand Down
24 changes: 23 additions & 1 deletion jni/include/commons.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,24 @@ namespace knn_jni {
* @param memoryAddress The address of the memory location where data will be stored.
* @param data 2D float array containing data to be stored in native memory.
* @param initialCapacity The initial capacity of the memory location.
* @return memory address where the data is stored.
* @return memory address of std::vector<float> where the data is stored.
*/
jlong storeVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong);

/**
* This is utility function that can be used to store data in native memory. This function will allocate memory for
* the data(rows*columns) with initialCapacity and return the memory address where the data is stored.
* If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created.
* For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location
* will throw Exception.
*
* @param memoryAddress The address of the memory location where data will be stored.
* @param data 2D byte array containing data to be stored in native memory.
* @param initialCapacity The initial capacity of the memory location.
* @return memory address of std::vector<uint8_t> where the data is stored.
*/
jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong);

/**
* Free up the memory allocated for the data stored in memory address. This function should be used with the memory
* address returned by {@link JNICommons#storeVectorData(long, float[][], long, long)}
Expand All @@ -34,6 +48,14 @@ namespace knn_jni {
*/
void freeVectorData(jlong);

/**
* Free up the memory allocated for the data stored in memory address. This function should be used with the memory
* address returned by {@link JNICommons#storeByteVectorData(long, byte[][], long, long)}
*
* @param memoryAddress address to be freed.
*/
void freeByteVectorData(jlong);

/**
* Extracts query time efSearch from method parameters
**/
Expand Down
113 changes: 113 additions & 0 deletions jni/include/faiss_index_service.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// SPDX-License-Identifier: Apache-2.0
//
// The OpenSearch Contributors require contributions made to
// this file be licensed under the Apache-2.0 license or a
// compatible open source license.
//
// Modifications Copyright OpenSearch Contributors. See
// GitHub history for details.

/**
* This file contains classes for index operations which are free of JNI
*/

#ifndef OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H
#define OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H

#include <jni.h>
#include "faiss/MetricType.h"
#include "jni_util.h"
#include "faiss_methods.h"
#include <memory>

namespace knn_jni {
namespace faiss_wrapper {


/**
* A class to provide operations on index
* This class should evolve to have only cpp object but not jni object
*/
class IndexService {
public:
IndexService(std::unique_ptr<FaissMethods> faissMethods);
//TODO Remove dependency on JNIUtilInterface and JNIEnv
//TODO Reduce the number of parameters

/**
* Create index
*
* @param jniUtil jni util
* @param env jni environment
* @param metric space type for distance calculation
* @param indexDescription index description to be used by faiss index factory
* @param dim dimension of vectors
* @param numIds number of vectors
* @param threadCount number of thread count to be used while adding data
* @param vectorsAddress memory address which is holding vector data
* @param ids a list of document ids for corresponding vectors
* @param indexPath path to write index
* @param parameters parameters to be applied to faiss index
*/
virtual void createIndex(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
faiss::MetricType metric,
std::string indexDescription,
int dim,
int numIds,
int threadCount,
int64_t vectorsAddress,
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters);
virtual ~IndexService() = default;
protected:
std::unique_ptr<FaissMethods> faissMethods;
};

/**
* A class to provide operations on index
* This class should evolve to have only cpp object but not jni object
*/
class BinaryIndexService : public IndexService {
public:
//TODO Remove dependency on JNIUtilInterface and JNIEnv
//TODO Reduce the number of parameters
BinaryIndexService(std::unique_ptr<FaissMethods> faissMethods);
/**
* Create binary index
*
* @param jniUtil jni util
* @param env jni environment
* @param metric space type for distance calculation
* @param indexDescription index description to be used by faiss index factory
* @param dim dimension of vectors
* @param numIds number of vectors
* @param threadCount number of thread count to be used while adding data
* @param vectorsAddress memory address which is holding vector data
* @param ids a list of document ids for corresponding vectors
* @param indexPath path to write index
* @param parameters parameters to be applied to faiss index
*/
virtual void createIndex(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
faiss::MetricType metric,
std::string indexDescription,
int dim,
int numIds,
int threadCount,
int64_t vectorsAddress,
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters
) override;
virtual ~BinaryIndexService() = default;
};

}
}


#endif //OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H
42 changes: 42 additions & 0 deletions jni/include/faiss_methods.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// SPDX-License-Identifier: Apache-2.0
//
// The OpenSearch Contributors require contributions made to
// this file be licensed under the Apache-2.0 license or a
// compatible open source license.
//
// Modifications Copyright OpenSearch Contributors. See
// GitHub history for details.

#ifndef OPENSEARCH_KNN_FAISS_METHODS_H
#define OPENSEARCH_KNN_FAISS_METHODS_H

#include "faiss/Index.h"
#include "faiss/IndexBinary.h"
#include "faiss/IndexIDMap.h"
#include "faiss/index_io.h"

namespace knn_jni {
namespace faiss_wrapper {

/**
* A class having wrapped faiss methods
*
* This class helps to mock faiss methods during unit test
*/
class FaissMethods {
public:
FaissMethods() = default;
virtual faiss::Index* indexFactory(int d, const char* description, faiss::MetricType metric);
virtual faiss::IndexBinary* indexBinaryFactory(int d, const char* description);
virtual faiss::IndexIDMapTemplate<faiss::Index>* indexIdMap(faiss::Index* index);
virtual faiss::IndexIDMapTemplate<faiss::IndexBinary>* indexBinaryIdMap(faiss::IndexBinary* index);
virtual void writeIndex(const faiss::Index* idx, const char* fname);
virtual void writeIndexBinary(const faiss::IndexBinary* idx, const char* fname);
virtual ~FaissMethods() = default;
};

} //namespace faiss_wrapper
} //namespace knn_jni


#endif //OPENSEARCH_KNN_FAISS_METHODS_H
14 changes: 13 additions & 1 deletion jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
#define OPENSEARCH_KNN_FAISS_WRAPPER_H

#include "jni_util.h"
#include "faiss_index_service.h"
#include <jni.h>

namespace knn_jni {
namespace faiss_wrapper {
// Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ.
// The index is serialized to indexPathJ.
void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ,
jstring indexPathJ, jobject parametersJ);
jstring indexPathJ, jobject parametersJ, IndexService* indexService);

// Create an index with ids and vectors. Instead of creating a new index, this function creates the index
// based off of the template index passed in. The index is serialized to indexPathJ.
Expand All @@ -33,6 +34,11 @@ namespace knn_jni {
// Return a pointer to the loaded index
jlong LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ);

// Load a binary index from indexPathJ into memory.
//
// Return a pointer to the loaded index
jlong LoadBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ);

// Check if a loaded index requires shared state
bool IsSharedIndexStateRequired(jlong indexPointerJ);

Expand Down Expand Up @@ -68,6 +74,12 @@ namespace knn_jni {
jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ,
jint filterIdsTypeJ, jintArray parentIdsJ);

// Execute a query against the binary index located in memory at indexPointerJ along with Filters
//
// Return an array of KNNQueryResults
jobjectArray QueryBinaryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jbyteArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ);

// Free the index located in memory at indexPointerJ
void Free(jlong indexPointer);

Expand Down
7 changes: 7 additions & 0 deletions jni/include/jni_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ namespace knn_jni {

virtual void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ,
int dim, std::vector<float> *vect ) = 0;
virtual void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ,
int dim, std::vector<uint8_t> *vect ) = 0;

virtual std::vector<int64_t> ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) = 0;

Expand All @@ -79,6 +81,8 @@ namespace knn_jni {
// ------------------------------ MISC HELPERS ------------------------------
virtual int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ) = 0;

virtual int GetInnerDimensionOf2dJavaByteArray(JNIEnv *env, jobjectArray array2dJ) = 0;

virtual int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ) = 0;

virtual int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ) = 0;
Expand Down Expand Up @@ -146,6 +150,7 @@ namespace knn_jni {
std::vector<float> Convert2dJavaObjectArrayToCppFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim);
std::vector<int64_t> ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ);
int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ);
int GetInnerDimensionOf2dJavaByteArray(JNIEnv *env, jobjectArray array2dJ);
int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ);
int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ);
int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ);
Expand All @@ -168,6 +173,7 @@ namespace knn_jni {
void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val);
void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf);
void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<float> *vect);
void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<uint8_t> *vect);

private:
std::unordered_map<std::string, jclass> cachedClasses;
Expand All @@ -193,6 +199,7 @@ namespace knn_jni {
extern const std::string COSINESIMIL;
extern const std::string INNER_PRODUCT;
extern const std::string NEG_DOT_PRODUCT;
extern const std::string HAMMING_BIT;

extern const std::string NPROBES;
extern const std::string COARSE_QUANTIZER;
Expand Down
29 changes: 27 additions & 2 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#ifdef __cplusplus
extern "C" {
#endif

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createIndex
Expand All @@ -26,6 +27,14 @@ extern "C" {
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createBinaryIndex
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndex
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createIndexFromTemplate
Expand All @@ -42,6 +51,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex
(JNIEnv *, jclass, jstring);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: loadBinaryIndex
* Signature: (Ljava/lang/String;)J
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndex
(JNIEnv *, jclass, jstring);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: isSharedIndexStateRequired
Expand Down Expand Up @@ -69,19 +86,27 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_setSharedIndexSt
/*
* Class: org_opensearch_knn_jni_FaissService
* Method: queryIndex
* Signature: (J[FI[Ljava/util/MapI)[Lorg/opensearch/knn/index/query/KNNQueryResult;
* Signature: (J[FILjava/util/Map[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndex
(JNIEnv *, jclass, jlong, jfloatArray, jint, jobject, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: queryIndexWithFilter
* Signature: (J[FI[JLjava/util/MapI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
* Signature: (J[FILjava/util/Map[JI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter
(JNIEnv *, jclass, jlong, jfloatArray, jint, jobject, jlongArray, jint, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: queryBIndexWithFilter
* Signature: (J[BILjava/util/Map[JI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBinaryIndexWithFilter
(JNIEnv *, jclass, jlong, jbyteArray, jint, jobject, jlongArray, jint, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: free
Expand Down
Loading
Loading