Skip to content

Commit

Permalink
Adds Nprobe as a method parameter in knn query (#1758) (#1792)
Browse files Browse the repository at this point in the history
Adds integration test for invalid method parameters (#1782)

Signed-off-by: Tejas Shah <[email protected]>
  • Loading branch information
shatejas authored Jul 16, 2024
1 parent 9705144 commit b422466
Show file tree
Hide file tree
Showing 15 changed files with 444 additions and 78 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
* Adds dynamic query parameter ef_search [#1783](https://github.com/opensearch-project/k-NN/pull/1783)
* Adds dynamic query parameter ef_search in radial search faiss engine [#1790](https://github.com/opensearch-project/k-NN/pull/1790)
* Adds dynamic query parameter nprobes [#1792](https://github.com/opensearch-project/k-NN/pull/1792)
* Add binary format support with HNSW method in Faiss Engine [#1781](https://github.com/opensearch-project/k-NN/pull/1781)
* Add script scoring support for knn field with binary data type [#1826](https://github.com/opensearch-project/k-NN/pull/1826)
### Enhancements
Expand Down
14 changes: 12 additions & 2 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,6 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
if (methodParamsJ != nullptr) {
methodParams = jniUtil->ConvertJavaMapToCppMap(env, methodParamsJ);
}

// The ids vector will hold the top k ids from the search and the dis vector will hold the top k distances from
// the query point
std::vector<float> dis(kJ);
Expand Down Expand Up @@ -358,7 +357,10 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
} else {
auto ivfReader = dynamic_cast<const faiss::IndexIVF*>(indexReader->index);
auto ivfFlatReader = dynamic_cast<const faiss::IndexIVFFlat*>(indexReader->index);

if(ivfReader || ivfFlatReader) {
int indexNprobe = ivfReader == nullptr ? ivfFlatReader->nprobe : ivfReader->nprobe;
ivfParams.nprobe = commons::getIntegerMethodParameter(env, jniUtil, methodParams, NPROBES, indexNprobe);
ivfParams.sel = idSelector.get();
searchParameters = &ivfParams;
}
Expand All @@ -374,17 +376,25 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
} else {
faiss::SearchParameters *searchParameters = nullptr;
faiss::SearchParametersHNSW hnswParams;
faiss::SearchParametersIVF ivfParams;
std::unique_ptr<faiss::IDGrouperBitmap> idGrouper;
std::vector<uint64_t> idGrouperBitmap;
auto hnswReader = dynamic_cast<const faiss::IndexHNSW*>(indexReader->index);
if(hnswReader!= nullptr) {
if(hnswReader != nullptr) {
// Query param efsearch supersedes ef_search provided during index setting.
hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch);
if (parentIdsJ != nullptr) {
idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap);
hnswParams.grp = idGrouper.get();
}
searchParameters = &hnswParams;
} else {
auto ivfReader = dynamic_cast<const faiss::IndexIVF*>(indexReader->index);
if (ivfReader) {
int indexNprobe = ivfReader->nprobe;
ivfParams.nprobe = commons::getIntegerMethodParameter(env, jniUtil, methodParams, NPROBES, indexNprobe);
searchParameters = &ivfParams;
}
}
try {
indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data(), searchParameters);
Expand Down
191 changes: 154 additions & 37 deletions jni/tests/faiss_wrapper_unit_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include "test_util.h"
#include "faiss/IndexHNSW.h"
#include "faiss/IndexIDMap.h"
#include "faiss/IndexIVFFlat.h"
#include "faiss/IndexIVFPQ.h"

using ::testing::NiceMock;

Expand All @@ -30,6 +32,46 @@ struct FaissMockIndex : faiss::IndexHNSW {
}
};

struct MockIVFIndex : faiss::IndexIVFFlat {
explicit MockIVFIndex() = default;
};

struct MockIVFIdMap : faiss::IndexIDMap {
mutable idx_t nCalled{};
mutable const float *xCalled{};
mutable idx_t kCalled{};
mutable float *distancesCalled{};
mutable idx_t *labelsCalled{};
mutable const faiss::SearchParametersIVF *paramsCalled{};

explicit MockIVFIdMap(MockIVFIndex *index) : faiss::IndexIDMapTemplate<faiss::Index>(index) {
}

void search(
idx_t n,
const float *x,
idx_t k,
float *distances,
idx_t *labels,
const faiss::SearchParameters *params) const override {
nCalled = n;
xCalled = x;
kCalled = k;
distancesCalled = distances;
labelsCalled = labels;
paramsCalled = dynamic_cast<const faiss::SearchParametersIVF *>(params);
}

void resetMock() const {
nCalled = 0;
xCalled = nullptr;
kCalled = 0;
distancesCalled = nullptr;
labelsCalled = nullptr;
paramsCalled = nullptr;
}
};

struct FaissMockIdMap : faiss::IndexIDMap {
mutable idx_t nCalled{};
mutable const float *xCalled{};
Expand Down Expand Up @@ -83,13 +125,14 @@ struct FaissMockIdMap : faiss::IndexIDMap {
}
};

struct QueryIndexHNSWTestInput {
std::string description;
struct QueryIndexInput {
string description;
int k;
int efSearch;
int filterIdType;
bool filterIdsPresent;
bool parentIdsPresent;
int efSearch;
int nprobe;
};

struct RangeSearchTestInput {
Expand All @@ -101,9 +144,9 @@ struct RangeSearchTestInput {
bool parentIdsPresent;
};

class FaissWrappeterParametrizedTestFixture : public testing::TestWithParam<QueryIndexHNSWTestInput> {
class FaissWrapperParameterizedTestFixture : public testing::TestWithParam<QueryIndexInput> {
public:
FaissWrappeterParametrizedTestFixture() : index_(3), id_map_(&index_) {
FaissWrapperParameterizedTestFixture() : index_(3), id_map_(&index_) {
index_.hnsw.efSearch = 100; // assigning 100 to make sure default of 16 is not used anywhere
}

Expand All @@ -112,9 +155,9 @@ class FaissWrappeterParametrizedTestFixture : public testing::TestWithParam<Quer
FaissMockIdMap id_map_;
};

class FaissWrapperParametrizedRangeSearchTestFixture : public testing::TestWithParam<RangeSearchTestInput> {
class FaissWrapperParameterizedRangeSearchTestFixture : public testing::TestWithParam<RangeSearchTestInput> {
public:
FaissWrapperParametrizedRangeSearchTestFixture() : index_(3), id_map_(&index_) {
FaissWrapperParameterizedRangeSearchTestFixture() : index_(3), id_map_(&index_) {
index_.hnsw.efSearch = 100; // assigning 100 to make sure default of 16 is not used anywhere
}

Expand All @@ -123,16 +166,24 @@ class FaissWrapperParametrizedRangeSearchTestFixture : public testing::TestWithP
FaissMockIdMap id_map_;
};

namespace query_index_test {
class FaissWrapperIVFQueryTestFixture : public testing::TestWithParam<QueryIndexInput> {
public:
FaissWrapperIVFQueryTestFixture() : ivf_id_map_(&ivf_index_) {
ivf_index_.nprobe = 100;
};

std::unordered_map<std::string, jobject> methodParams;
protected:
MockIVFIndex ivf_index_;
MockIVFIdMap ivf_id_map_;
};

TEST_P(FaissWrappeterParametrizedTestFixture, QueryIndexHNSWTests) {
// Given
namespace query_index_test {
TEST_P(FaissWrapperParameterizedTestFixture, QueryIndexHNSWTests) {
//Given
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;

QueryIndexHNSWTestInput const &input = GetParam();
QueryIndexInput const &input = GetParam();
float query[] = {1.2, 2.3, 3.4};

int efSearch = input.efSearch;
Expand Down Expand Up @@ -184,24 +235,23 @@ namespace query_index_test {

INSTANTIATE_TEST_CASE_P(
QueryIndexHNSWTests,
FaissWrappeterParametrizedTestFixture,
FaissWrapperParameterizedTestFixture,
::testing::Values(
QueryIndexHNSWTestInput{"algoParams present, parent absent", 10, 200, 0, false, false},
QueryIndexHNSWTestInput{"algoParams absent, parent absent", 10, -1, 0, false, false},
QueryIndexHNSWTestInput{"algoParams present, parent present", 10, 200, 0, false, true},
QueryIndexHNSWTestInput{"algoParams absent, parent present", 10, -1, 0, false, true}
QueryIndexInput {"algoParams present, parent absent", 10, 0, false, false, 200, -1 },
QueryIndexInput {"algoParams present, parent absent", 10, 0, false, false, -1, -1 },
QueryIndexInput {"algoParams present, parent present", 10, 0, false, true, 200, -1 },
QueryIndexInput {"algoParams absent, parent present", 10, 0, false, true, -1, -1 }
)
);
}

namespace query_index_with_filter_test {

TEST_P(FaissWrappeterParametrizedTestFixture, QueryIndexWithFilterHNSWTests) {
// Given
TEST_P(FaissWrapperParameterizedTestFixture, QueryIndexWithFilterHNSWTests) {
//Given
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;

QueryIndexHNSWTestInput const &input = GetParam();
QueryIndexInput const &input = GetParam();
float query[] = {1.2, 2.3, 3.4};

std::vector<int> *parentIdPtr = nullptr;
Expand Down Expand Up @@ -267,23 +317,26 @@ namespace query_index_with_filter_test {

INSTANTIATE_TEST_CASE_P(
QueryIndexWithFilterHNSWTests,
FaissWrappeterParametrizedTestFixture,
FaissWrapperParameterizedTestFixture,
::testing::Values(
QueryIndexHNSWTestInput{"algoParams present, parent absent, filter absent", 10, 200, 0, false, false},
QueryIndexHNSWTestInput{"algoParams present, parent absent, filter absent, filter type 1", 10, 200, 1, false, false},
QueryIndexHNSWTestInput{"algoParams absent, parent absent, filter present", 10, -1, 0, true, false},
QueryIndexHNSWTestInput{"algoParams absent, parent absent, filter present, filter type 1", 10, -1, 1, true, false},
QueryIndexHNSWTestInput{"algoParams present, parent present, filter absent", 10, 200, 0, false, true},
QueryIndexHNSWTestInput{"algoParams present, parent present, filter absent, filter type 1", 10, 150, 1, false, true},
QueryIndexHNSWTestInput{"algoParams absent, parent present, filter present", 10, -1, 0, true, true},
QueryIndexHNSWTestInput{"algoParams absent, parent present, filter present, filter type 1",10, -1, 1, true, true}
QueryIndexInput { "algoParams present, parent absent, filter absent", 10, 0, false, false, 200, -1 },
QueryIndexInput { "algoParams present, parent absent, filter absent, filter type 1", 10, 1, false, false,
200, -1},
QueryIndexInput { "algoParams absent, parent absent, filter present", 10, 0, true, false, -1, -1},
QueryIndexInput { "algoParams absent, parent absent, filter present, filter type 1", 10, 1, true, false, -1,
-1},
QueryIndexInput { "algoParams present, parent present, filter absent", 10, 0, false, true, 200, -1 },
QueryIndexInput { "algoParams present, parent present, filter absent, filter type 1", 10, 1, false, true,
150, -1},
QueryIndexInput { "algoParams absent, parent present, filter present", 10, 0, true, true, -1, -1},
QueryIndexInput { "algoParams absent, parent present, filter present, filter type 1",10, 1, true, true, -1,
-1 }
)
);
}

namespace range_search_test {

TEST_P(FaissWrapperParametrizedRangeSearchTestFixture, RangeSearchHNSWTests) {
TEST_P(FaissWrapperParameterizedRangeSearchTestFixture, RangeSearchHNSWTests) {
// Given
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;
Expand Down Expand Up @@ -323,6 +376,7 @@ namespace range_search_test {
std::vector<long> filter;
std::vector<long> *filterptr = nullptr;
if (input.filterIdsPresent) {
std::vector<long> filter;
filter.reserve(2);
filter.push_back(1);
filter.push_back(2);
Expand Down Expand Up @@ -356,16 +410,79 @@ namespace range_search_test {

INSTANTIATE_TEST_CASE_P(
RangeSearchHNSWTests,
FaissWrapperParametrizedRangeSearchTestFixture,
FaissWrapperParameterizedRangeSearchTestFixture,
::testing::Values(
RangeSearchTestInput{"algoParams present, parent absent, filter absent", 10.0f, 200, 0, false, false},
RangeSearchTestInput{"algoParams present, parent absent, filter absent, filter type 1", 10.0f, 200, 1, false, false},
RangeSearchTestInput{"algoParams present, parent absent, filter absent, filter type 1", 10.0f, 200, 1, false
, false},
RangeSearchTestInput{"algoParams absent, parent absent, filter present", 10.0f, -1, 0, true, false},
RangeSearchTestInput{"algoParams absent, parent absent, filter present, filter type 1", 10.0f, -1, 1, true, false},
RangeSearchTestInput{"algoParams absent, parent absent, filter present, filter type 1", 10.0f, -1, 1, true,
false},
RangeSearchTestInput{"algoParams present, parent present, filter absent", 10.0f, 200, 0, false, true},
RangeSearchTestInput{"algoParams present, parent present, filter absent, filter type 1", 10.0f, 150, 1, false, true},
RangeSearchTestInput{"algoParams present, parent present, filter absent, filter type 1", 10.0f, 150, 1,
false, true},
RangeSearchTestInput{"algoParams absent, parent present, filter present", 10.0f, -1, 0, true, true},
RangeSearchTestInput{"algoParams absent, parent present, filter present, filter type 1", 10.0f, -1, 1, true, true}
RangeSearchTestInput{"algoParams absent, parent present, filter present, filter type 1", 10.0f, -1, 1, true,
true}
)
);
}

namespace query_index_with_filter_test_ivf {
TEST_P(FaissWrapperIVFQueryTestFixture, QueryIndexIVFTest) {
//Given
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;

QueryIndexInput const &input = GetParam();
float query[] = {1.2, 2.3, 3.4};

int nprobe = input.nprobe;
int expectedNprobe = 100; //default set in mock
std::unordered_map<std::string, jobject> methodParams;
if (nprobe != -1) {
expectedNprobe = input.nprobe;
methodParams[knn_jni::NPROBES] = reinterpret_cast<jobject>(&nprobe);
}

std::vector<long> *filterptr = nullptr;
if (input.filterIdsPresent) {
std::vector<long> filter;
filter.reserve(2);
filter.push_back(1);
filter.push_back(2);
filterptr = &filter;
}

// When
knn_jni::faiss_wrapper::QueryIndex_WithFilter(
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&ivf_id_map_),
reinterpret_cast<jfloatArray>(&query), input.k, reinterpret_cast<jobject>(&methodParams),
reinterpret_cast<jlongArray>(filterptr),
input.filterIdType,
nullptr);

//Then
int actualEfSearch = ivf_id_map_.paramsCalled->nprobe;
// Asserting the captured argument
EXPECT_EQ(input.k, ivf_id_map_.kCalled);
EXPECT_EQ(expectedNprobe, actualEfSearch);
if (input.parentIdsPresent) {
faiss::IDGrouper *grouper = ivf_id_map_.paramsCalled->grp;
EXPECT_TRUE(grouper != nullptr);
}
ivf_id_map_.resetMock();
}

INSTANTIATE_TEST_CASE_P(
QueryIndexIVFTest,
FaissWrapperIVFQueryTestFixture,
::testing::Values(
QueryIndexInput{"algoParams present, parent absent", 10, 0, false, false, -1, 200 },
QueryIndexInput{"algoParams present, parent absent", 10,0, false, false, -1, -1 },
QueryIndexInput{"algoParams present, parent present", 10, 0, true, true, -1, 200 },
QueryIndexInput{"algoParams absent, parent present", 10, 0, true, true, -1, -1 }
)
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES;
import static org.opensearch.knn.common.KNNConstants.MIN_SCORE;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;
import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion;
Expand All @@ -74,6 +75,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
public static final ParseField MAX_DISTANCE_FIELD = new ParseField(MAX_DISTANCE);
public static final ParseField MIN_SCORE_FIELD = new ParseField(MIN_SCORE);
public static final ParseField EF_SEARCH_FIELD = new ParseField(METHOD_PARAMETER_EF_SEARCH);
public static final ParseField NPROBE_FIELD = new ParseField(METHOD_PARAMETER_NPROBES);
public static final ParseField METHOD_PARAMS_FIELD = new ParseField(METHOD_PARAMETER);
public static final int K_MAX = 10000;
/**
Expand Down Expand Up @@ -222,9 +224,8 @@ private void validate() {

if (k != null) {
if (k <= 0 || k > K_MAX) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "[%s] requires k to be in the range (0, %d]", NAME, K_MAX)
);
final String errorMessage = "[" + NAME + "] requires k to be in the range (0, " + K_MAX + "]";
throw new IllegalArgumentException(errorMessage);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
import static org.opensearch.knn.index.query.KNNQueryBuilder.METHOD_PARAMS_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.NAME;

/**
* Note: This parser is used by neural plugin as well, breaking changes will require changes in neural as well
*/
@EqualsAndHashCode
@Getter
@AllArgsConstructor
Expand Down
Loading

0 comments on commit b422466

Please sign in to comment.