Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support filter and nested field in faiss engine radial search #1652

Merged
merged 3 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.13...2.x)
### Features
* Support radial search in k-NN plugin [#1617](https://github.com/opensearch-project/k-NN/pull/1617)
* Support filter and nested field in faiss engine radial search [#1652](https://github.com/opensearch-project/k-NN/pull/1652)
### Enhancements
* Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549)
* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573)
Expand Down
2 changes: 1 addition & 1 deletion jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ endif ()
# build workflow once, it can cause issues because git commits require that the user and the user's email be set.
# See https://github.com/opensearch-project/k-NN/issues/1651. So, we provide a flag that allows users to select between
# the two
if(NOT DEFINED COMMIT_LIB_PATCHES OR ${COMMIT_LIB_PATCHES} STREQUAL true)
if(NOT DEFINED COMMIT_LIB_PATCHES OR "${COMMIT_LIB_PATCHES}" STREQUAL true)
set(GIT_PATCH_COMMAND am)
else()
set(GIT_PATCH_COMMAND apply)
Expand Down
3 changes: 2 additions & 1 deletion jni/cmake/init-faiss.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ if (NOT EXISTS ${FAISS_REPO_DIR})
endif ()

# Check if patch exist, this is to skip git apply during CI build. See CI.yml with ubuntu.
find_path(PATCH_FILE NAMES 0001-Custom-patch-to-support-multi-vector.patch 0002-Enable-precomp-table-to-be-shared-ivfpq.patch PATHS ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss NO_DEFAULT_PATH)
find_path(PATCH_FILE NAMES 0001-Custom-patch-to-support-multi-vector.patch 0002-Enable-precomp-table-to-be-shared-ivfpq.patch 0003-Custom-patch-to-support-range-search-params.patch PATHS ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss NO_DEFAULT_PATH)

# If it exists, apply patches
if (EXISTS ${PATCH_FILE})
message(STATUS "Applying custom patches.")
execute_process(COMMAND git ${GIT_PATCH_COMMAND} --3way --ignore-space-change --ignore-whitespace ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss ERROR_VARIABLE ERROR_MSG RESULT_VARIABLE RESULT_CODE)
execute_process(COMMAND git ${GIT_PATCH_COMMAND} --3way --ignore-space-change --ignore-whitespace ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0002-Enable-precomp-table-to-be-shared-ivfpq.patch WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss ERROR_VARIABLE ERROR_MSG RESULT_VARIABLE RESULT_CODE)
execute_process(COMMAND git ${GIT_PATCH_COMMAND} --3way --ignore-space-change --ignore-whitespace ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0003-Custom-patch-to-support-range-search-params.patch WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss ERROR_VARIABLE ERROR_MSG RESULT_VARIABLE RESULT_CODE)
if(RESULT_CODE)
message(FATAL_ERROR "Failed to apply patch:\n${ERROR_MSG}")
endif()
Expand Down
20 changes: 19 additions & 1 deletion jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,35 @@ namespace knn_jni {
jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
jlong trainVectorsPointerJ);

/*
* Perform a range search with filter against the index located in memory at indexPointerJ.
*
* @param indexPointerJ - pointer to the index
* @param queryVectorJ - the query vector
* @param radiusJ - the radius for the range search
* @param maxResultsWindowJ - the maximum number of results to return
* @param filterIdsJ - the filter ids
* @param filterIdsTypeJ - the filter ids type
* @param parentIdsJ - the parent ids
*
* @return an array of RangeQueryResults
*/
jobjectArray RangeSearchWithFilter(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, jfloatArray queryVectorJ,
jfloat radiusJ, jint maxResultWindowJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ);

/*
* Perform a range search against the index located in memory at indexPointerJ.
*
* @param indexPointerJ - pointer to the index
* @param queryVectorJ - the query vector
* @param radiusJ - the radius for the range search
* @param maxResultsWindowJ - the maximum number of results to return
* @param parentIdsJ - the parent ids
*
* @return an array of RangeQueryResults
*/
jobjectArray RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, jfloatArray queryVectorJ,
jfloat radiusJ, jint maxResultsWindowJ);
jfloat radiusJ, jint maxResultWindowJ, jintArray parentIdsJ);
}
}

Expand Down
14 changes: 11 additions & 3 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,19 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: rangeSearchIndex
* Signature: (J[F[F)J
* Method: rangeSearchIndexWithFilter
* Signature: (J[FJ[I)[Lorg/opensearch/knn/index/query/RangeQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndexWithFilter
(JNIEnv *, jclass, jlong, jfloatArray, jfloat, jint, jlongArray, jint, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: rangeSearchIndex
* Signature: (J[FJ[I)[Lorg/opensearch/knn/index/query/RangeQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex
(JNIEnv *, jclass, jlong, jfloatArray, jfloat, jint);
(JNIEnv *, jclass, jlong, jfloatArray, jfloat, jint, jintArray);

#ifdef __cplusplus
}
Expand Down
Copy link
Member Author

@junqiu-lei junqiu-lei Apr 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check this commit for the patch updated code. CC: @navneet1v

Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
From af6770b505a32b2c4eab2036d2509dec4b137f28 Mon Sep 17 00:00:00 2001
From: Junqiu Lei <[email protected]>
Date: Tue, 23 Apr 2024 17:18:56 -0700
Subject: [PATCH] Custom patch to support range search params

Signed-off-by: Junqiu Lei <[email protected]>
---
faiss/IndexIDMap.cpp | 28 ++++++++++++++++++++++++----
1 file changed, 24 insertions(+), 4 deletions(-)

diff --git a/faiss/IndexIDMap.cpp b/faiss/IndexIDMap.cpp
index 3f375e7b..11f3a847 100644
--- a/faiss/IndexIDMap.cpp
+++ b/faiss/IndexIDMap.cpp
@@ -176,11 +176,31 @@ void IndexIDMapTemplate<IndexT>::range_search(
RangeSearchResult* result,
const SearchParameters* params) const {
if (params) {
- SearchParameters internal_search_parameters;
- IDSelectorTranslated id_selector_translated(id_map, params->sel);
- internal_search_parameters.sel = &id_selector_translated;
+ IDSelectorTranslated this_idtrans(this->id_map, nullptr);
+ ScopedSelChange sel_change;
+ IDGrouperTranslated this_idgrptrans(this->id_map, nullptr);
+ ScopedGrpChange grp_change;
+
+ if (params->sel) {
+ auto idtrans = dynamic_cast<const IDSelectorTranslated*>(params->sel);
+
+ if (!idtrans) {
+ auto params_non_const = const_cast<SearchParameters*>(params);
+ this_idtrans.sel = params->sel;
+ sel_change.set(params_non_const, &this_idtrans);
+ }
+ }
+
+ if (params->grp) {
+ auto idtrans = dynamic_cast<const IDGrouperTranslated*>(params->grp);

- index->range_search(n, x, radius, result, &internal_search_parameters);
+ if (!idtrans) {
+ auto params_non_const = const_cast<SearchParameters*>(params);
+ this_idgrptrans.grp = params->grp;
+ grp_change.set(params_non_const, &this_idgrptrans);
+ }
+ }
+ index->range_search(n, x, radius, result, params);
} else {
index->range_search(n, x, radius, result);
}
--
2.39.0

71 changes: 69 additions & 2 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,12 @@ faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index) {
}

jobjectArray knn_jni::faiss_wrapper::RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ,
jfloatArray queryVectorJ, jfloat radiusJ, jint maxResultWindowJ) {
jfloatArray queryVectorJ, jfloat radiusJ, jint maxResultWindowJ, jintArray parentIdsJ) {
return knn_jni::faiss_wrapper::RangeSearchWithFilter(jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, maxResultWindowJ, nullptr, 0, parentIdsJ);
}

jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ,
jfloatArray queryVectorJ, jfloat radiusJ, jint maxResultWindowJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) {
if (queryVectorJ == nullptr) {
throw std::runtime_error("Query Vector cannot be null");
}
Expand All @@ -605,7 +610,69 @@ jobjectArray knn_jni::faiss_wrapper::RangeSearch(knn_jni::JNIUtilInterface *jniU
// The res will be freed by ~RangeSearchResult() in FAISS
// The second parameter is always true, as lims is allocated by FAISS
faiss::RangeSearchResult res(1, true);
indexReader->range_search(1, rawQueryVector, radiusJ, &res);

if(filterIdsJ != nullptr) {
jlong *filteredIdsArray = jniUtil->GetLongArrayElements(env, filterIdsJ, nullptr);
int filterIdsLength = jniUtil->GetJavaLongArrayLength(env, filterIdsJ);
std::unique_ptr<faiss::IDSelector> idSelector;
if(filterIdsTypeJ == BITMAP) {
idSelector.reset(new faiss::IDSelectorJlongBitmap(filterIdsLength, filteredIdsArray));
} else {
faiss::idx_t* batchIndices = reinterpret_cast<faiss::idx_t*>(filteredIdsArray);
idSelector.reset(new faiss::IDSelectorBatch(filterIdsLength, batchIndices));
}
faiss::SearchParameters *searchParameters;
faiss::SearchParametersHNSW hnswParams;
faiss::SearchParametersIVF ivfParams;
std::unique_ptr<faiss::IDGrouperBitmap> idGrouper;
std::vector<uint64_t> idGrouperBitmap;
auto hnswReader = dynamic_cast<const faiss::IndexHNSW*>(indexReader->index);
if(hnswReader) {
// Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default
// value of ef_search = 16 which will then be used.
hnswParams.efSearch = hnswReader->hnsw.efSearch;
hnswParams.sel = idSelector.get();
if (parentIdsJ != nullptr) {
idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap);
hnswParams.grp = idGrouper.get();
}
searchParameters = &hnswParams;
} else {
auto ivfReader = dynamic_cast<const faiss::IndexIVF*>(indexReader->index);
auto ivfFlatReader = dynamic_cast<const faiss::IndexIVFFlat*>(indexReader->index);
if(ivfReader || ivfFlatReader) {
ivfParams.sel = idSelector.get();
searchParameters = &ivfParams;
}
}
try {
indexReader->range_search(1, rawQueryVector, radiusJ, &res, searchParameters);
} catch (...) {
jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryVector, JNI_ABORT);
jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT);
throw;
}
} else {
faiss::SearchParameters *searchParameters = nullptr;
faiss::SearchParametersHNSW hnswParams;
std::unique_ptr<faiss::IDGrouperBitmap> idGrouper;
std::vector<uint64_t> idGrouperBitmap;
auto hnswReader = dynamic_cast<const faiss::IndexHNSW*>(indexReader->index);
if(hnswReader!= nullptr && parentIdsJ != nullptr) {
// Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default
// value of ef_search = 16 which will then be used.
hnswParams.efSearch = hnswReader->hnsw.efSearch;
idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap);
hnswParams.grp = idGrouper.get();
searchParameters = &hnswParams;
}
try {
indexReader->range_search(1, rawQueryVector, radiusJ, &res, searchParameters);
} catch (...) {
jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryVector, JNI_ABORT);
throw;
}
}

// lims is structured to support batched queries, it has a length of nq + 1 (where nq is the number of queries),
// lims[i] - lims[i-1] gives the number of results for the i-th query. With a single query we used in k-NN,
Expand Down
19 changes: 17 additions & 2 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,26 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex(JNIEnv * env, jclass cls,
jlong indexPointerJ,
jfloatArray queryVectorJ,
jfloat radiusJ, jint maxResultWindowJ)
jfloat radiusJ, jint maxResultWindowJ,
jintArray parentIdsJ)
{
try {
return knn_jni::faiss_wrapper::RangeSearch(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, maxResultWindowJ);
return knn_jni::faiss_wrapper::RangeSearch(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, maxResultWindowJ, parentIdsJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
return nullptr;
}

JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndexWithFilter(JNIEnv * env, jclass cls,
jlong indexPointerJ,
jfloatArray queryVectorJ,
jfloat radiusJ, jint maxResultWindowJ,
jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ)
{
try {
return knn_jni::faiss_wrapper::RangeSearchWithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ,
maxResultWindowJ, filterIdsJ, filterIdsTypeJ, parentIdsJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand Down
Loading
Loading