Skip to content

Commit

Permalink
Add jni interface to use a binary hnsw index with faiss (#1778)
Browse files Browse the repository at this point in the history
* Add jni interface to use a binary hnsw index with faiss (#1747)

Signed-off-by: Heemin Kim <[email protected]>

* Fix memory leak on test code (#1776)

Signed-off-by: Heemin Kim <[email protected]>

---------

Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 authored and luyuncheng committed Jul 7, 2024
1 parent 7e95e16 commit 083f668
Show file tree
Hide file tree
Showing 36 changed files with 1,637 additions and 106 deletions.
4 changes: 3 additions & 1 deletion jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ set(TARGET_LIBS "") # Libs to be installed

set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED True)

option(CONFIG_FAISS "Configure faiss library build when this is on")
option(CONFIG_NMSLIB "Configure nmslib library build when this is on")
option(CONFIG_TEST "Configure tests when this is on")
Expand Down Expand Up @@ -112,6 +111,8 @@ if (${CONFIG_FAISS} STREQUAL ON OR ${CONFIG_ALL} STREQUAL ON OR ${CONFIG_TEST} S
${CMAKE_CURRENT_SOURCE_DIR}/src/org_opensearch_knn_jni_FaissService.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_wrapper.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_util.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_index_service.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_methods.cpp
)
target_link_libraries(${TARGET_LIB_FAISS} ${TARGET_LINK_FAISS_LIB} ${TARGET_LIB_UTIL} OpenMP::OpenMP_CXX)
target_include_directories(${TARGET_LIB_FAISS} PRIVATE
Expand Down Expand Up @@ -153,6 +154,7 @@ if ("${WIN32}" STREQUAL "")
tests/nmslib_wrapper_unit_test.cpp
tests/test_util.cpp
tests/commons_test.cpp
tests/faiss_index_service_test.cpp
)

target_link_libraries(
Expand Down
24 changes: 23 additions & 1 deletion jni/include/commons.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,24 @@ namespace knn_jni {
* @param memoryAddress The address of the memory location where data will be stored.
* @param data 2D float array containing data to be stored in native memory.
* @param initialCapacity The initial capacity of the memory location.
* @return memory address where the data is stored.
* @return memory address of std::vector<float> where the data is stored.
*/
jlong storeVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong);

/**
* This is utility function that can be used to store data in native memory. This function will allocate memory for
* the data(rows*columns) with initialCapacity and return the memory address where the data is stored.
* If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created.
* For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location
* will throw Exception.
*
* @param memoryAddress The address of the memory location where data will be stored.
* @param data 2D byte array containing data to be stored in native memory.
* @param initialCapacity The initial capacity of the memory location.
* @return memory address of std::vector<uint8_t> where the data is stored.
*/
jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong);

/**
* Free up the memory allocated for the data stored in memory address. This function should be used with the memory
* address returned by {@link JNICommons#storeVectorData(long, float[][], long, long)}
Expand All @@ -34,6 +48,14 @@ namespace knn_jni {
*/
void freeVectorData(jlong);

/**
* Free up the memory allocated for the data stored in memory address. This function should be used with the memory
* address returned by {@link JNICommons#storeByteVectorData(long, byte[][], long, long)}
*
* @param memoryAddress address to be freed.
*/
void freeByteVectorData(jlong);

/**
* Extracts query time efSearch from method parameters
**/
Expand Down
113 changes: 113 additions & 0 deletions jni/include/faiss_index_service.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// SPDX-License-Identifier: Apache-2.0
//
// The OpenSearch Contributors require contributions made to
// this file be licensed under the Apache-2.0 license or a
// compatible open source license.
//
// Modifications Copyright OpenSearch Contributors. See
// GitHub history for details.

/**
* This file contains classes for index operations which are free of JNI
*/

#ifndef OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H
#define OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H

#include <jni.h>
#include "faiss/MetricType.h"
#include "jni_util.h"
#include "faiss_methods.h"
#include <memory>

namespace knn_jni {
namespace faiss_wrapper {


/**
* A class to provide operations on index
* This class should evolve to have only cpp object but not jni object
*/
class IndexService {
public:
IndexService(std::unique_ptr<FaissMethods> faissMethods);
//TODO Remove dependency on JNIUtilInterface and JNIEnv
//TODO Reduce the number of parameters

/**
* Create index
*
* @param jniUtil jni util
* @param env jni environment
* @param metric space type for distance calculation
* @param indexDescription index description to be used by faiss index factory
* @param dim dimension of vectors
* @param numIds number of vectors
* @param threadCount number of thread count to be used while adding data
* @param vectorsAddress memory address which is holding vector data
* @param ids a list of document ids for corresponding vectors
* @param indexPath path to write index
* @param parameters parameters to be applied to faiss index
*/
virtual void createIndex(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
faiss::MetricType metric,
std::string indexDescription,
int dim,
int numIds,
int threadCount,
int64_t vectorsAddress,
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters);
virtual ~IndexService() = default;
protected:
std::unique_ptr<FaissMethods> faissMethods;
};

/**
* A class to provide operations on index
* This class should evolve to have only cpp object but not jni object
*/
class BinaryIndexService : public IndexService {
public:
//TODO Remove dependency on JNIUtilInterface and JNIEnv
//TODO Reduce the number of parameters
BinaryIndexService(std::unique_ptr<FaissMethods> faissMethods);
/**
* Create binary index
*
* @param jniUtil jni util
* @param env jni environment
* @param metric space type for distance calculation
* @param indexDescription index description to be used by faiss index factory
* @param dim dimension of vectors
* @param numIds number of vectors
* @param threadCount number of thread count to be used while adding data
* @param vectorsAddress memory address which is holding vector data
* @param ids a list of document ids for corresponding vectors
* @param indexPath path to write index
* @param parameters parameters to be applied to faiss index
*/
virtual void createIndex(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
faiss::MetricType metric,
std::string indexDescription,
int dim,
int numIds,
int threadCount,
int64_t vectorsAddress,
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters
) override;
virtual ~BinaryIndexService() = default;
};

}
}


#endif //OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H
42 changes: 42 additions & 0 deletions jni/include/faiss_methods.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// SPDX-License-Identifier: Apache-2.0
//
// The OpenSearch Contributors require contributions made to
// this file be licensed under the Apache-2.0 license or a
// compatible open source license.
//
// Modifications Copyright OpenSearch Contributors. See
// GitHub history for details.

#ifndef OPENSEARCH_KNN_FAISS_METHODS_H
#define OPENSEARCH_KNN_FAISS_METHODS_H

#include "faiss/Index.h"
#include "faiss/IndexBinary.h"
#include "faiss/IndexIDMap.h"
#include "faiss/index_io.h"

namespace knn_jni {
namespace faiss_wrapper {

/**
* A class having wrapped faiss methods
*
* This class helps to mock faiss methods during unit test
*/
class FaissMethods {
public:
FaissMethods() = default;
virtual faiss::Index* indexFactory(int d, const char* description, faiss::MetricType metric);
virtual faiss::IndexBinary* indexBinaryFactory(int d, const char* description);
virtual faiss::IndexIDMapTemplate<faiss::Index>* indexIdMap(faiss::Index* index);
virtual faiss::IndexIDMapTemplate<faiss::IndexBinary>* indexBinaryIdMap(faiss::IndexBinary* index);
virtual void writeIndex(const faiss::Index* idx, const char* fname);
virtual void writeIndexBinary(const faiss::IndexBinary* idx, const char* fname);
virtual ~FaissMethods() = default;
};

} //namespace faiss_wrapper
} //namespace knn_jni


#endif //OPENSEARCH_KNN_FAISS_METHODS_H
14 changes: 13 additions & 1 deletion jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
#define OPENSEARCH_KNN_FAISS_WRAPPER_H

#include "jni_util.h"
#include "faiss_index_service.h"
#include <jni.h>

namespace knn_jni {
namespace faiss_wrapper {
// Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ.
// The index is serialized to indexPathJ.
void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ,
jstring indexPathJ, jobject parametersJ);
jstring indexPathJ, jobject parametersJ, IndexService* indexService);

// Create an index with ids and vectors. Instead of creating a new index, this function creates the index
// based off of the template index passed in. The index is serialized to indexPathJ.
Expand All @@ -33,6 +34,11 @@ namespace knn_jni {
// Return a pointer to the loaded index
jlong LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ);

// Load a binary index from indexPathJ into memory.
//
// Return a pointer to the loaded index
jlong LoadBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ);

// Check if a loaded index requires shared state
bool IsSharedIndexStateRequired(jlong indexPointerJ);

Expand Down Expand Up @@ -68,6 +74,12 @@ namespace knn_jni {
jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ,
jint filterIdsTypeJ, jintArray parentIdsJ);

// Execute a query against the binary index located in memory at indexPointerJ along with Filters
//
// Return an array of KNNQueryResults
jobjectArray QueryBinaryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jbyteArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ);

// Free the index located in memory at indexPointerJ
void Free(jlong indexPointer);

Expand Down
7 changes: 7 additions & 0 deletions jni/include/jni_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ namespace knn_jni {

virtual void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ,
int dim, std::vector<float> *vect ) = 0;
virtual void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ,
int dim, std::vector<uint8_t> *vect ) = 0;

virtual std::vector<int64_t> ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) = 0;

Expand All @@ -79,6 +81,8 @@ namespace knn_jni {
// ------------------------------ MISC HELPERS ------------------------------
virtual int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ) = 0;

virtual int GetInnerDimensionOf2dJavaByteArray(JNIEnv *env, jobjectArray array2dJ) = 0;

virtual int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ) = 0;

virtual int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ) = 0;
Expand Down Expand Up @@ -146,6 +150,7 @@ namespace knn_jni {
std::vector<float> Convert2dJavaObjectArrayToCppFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim);
std::vector<int64_t> ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ);
int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ);
int GetInnerDimensionOf2dJavaByteArray(JNIEnv *env, jobjectArray array2dJ);
int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ);
int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ);
int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ);
Expand All @@ -168,6 +173,7 @@ namespace knn_jni {
void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val);
void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf);
void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<float> *vect);
void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<uint8_t> *vect);

private:
std::unordered_map<std::string, jclass> cachedClasses;
Expand All @@ -193,6 +199,7 @@ namespace knn_jni {
extern const std::string COSINESIMIL;
extern const std::string INNER_PRODUCT;
extern const std::string NEG_DOT_PRODUCT;
extern const std::string HAMMING_BIT;

extern const std::string NPROBES;
extern const std::string COARSE_QUANTIZER;
Expand Down
29 changes: 27 additions & 2 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#ifdef __cplusplus
extern "C" {
#endif

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createIndex
Expand All @@ -26,6 +27,14 @@ extern "C" {
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createBinaryIndex
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndex
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createIndexFromTemplate
Expand All @@ -42,6 +51,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex
(JNIEnv *, jclass, jstring);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: loadBinaryIndex
* Signature: (Ljava/lang/String;)J
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndex
(JNIEnv *, jclass, jstring);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: isSharedIndexStateRequired
Expand Down Expand Up @@ -69,19 +86,27 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_setSharedIndexSt
/*
* Class: org_opensearch_knn_jni_FaissService
* Method: queryIndex
* Signature: (J[FI[Ljava/util/MapI)[Lorg/opensearch/knn/index/query/KNNQueryResult;
* Signature: (J[FILjava/util/Map[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndex
(JNIEnv *, jclass, jlong, jfloatArray, jint, jobject, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: queryIndexWithFilter
* Signature: (J[FI[JLjava/util/MapI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
* Signature: (J[FILjava/util/Map[JI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter
(JNIEnv *, jclass, jlong, jfloatArray, jint, jobject, jlongArray, jint, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: queryBIndexWithFilter
* Signature: (J[BILjava/util/Map[JI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBinaryIndexWithFilter
(JNIEnv *, jclass, jlong, jbyteArray, jint, jobject, jlongArray, jint, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: free
Expand Down
Loading

0 comments on commit 083f668

Please sign in to comment.