Skip to content

Commit

Permalink
Adds Nprobe as a method parameter in knn query (opensearch-project#1758)
Browse files Browse the repository at this point in the history
Adds integration test for invalid method parameters (opensearch-project#1782)

Signed-off-by: Tejas Shah <[email protected]>
  • Loading branch information
shatejas committed Jul 3, 2024
1 parent e7d7ec8 commit c034ec1
Show file tree
Hide file tree
Showing 14 changed files with 435 additions and 70 deletions.
14 changes: 12 additions & 2 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,6 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
if (methodParamsJ != nullptr) {
methodParams = jniUtil->ConvertJavaMapToCppMap(env, methodParamsJ);
}

// The ids vector will hold the top k ids from the search and the dis vector will hold the top k distances from
// the query point
std::vector<float> dis(kJ);
Expand Down Expand Up @@ -358,7 +357,10 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
} else {
auto ivfReader = dynamic_cast<const faiss::IndexIVF*>(indexReader->index);
auto ivfFlatReader = dynamic_cast<const faiss::IndexIVFFlat*>(indexReader->index);

// Filtered search on an IVF-family index: carry nprobe and the id selector in IVF search parameters.
if(ivfReader || ivfFlatReader) {
    // Fallback nprobe comes from the index itself. Use the IndexIVF view when that cast
    // succeeded and the IndexIVFFlat view otherwise, so a null pointer is never dereferenced.
    int indexNprobe = ivfReader != nullptr ? ivfReader->nprobe : ivfFlatReader->nprobe;
    // An nprobe supplied in the query's method parameters takes precedence over the index value.
    ivfParams.nprobe = commons::getIntegerMethodParameter(env, jniUtil, methodParams, NPROBES, indexNprobe);
    ivfParams.sel = idSelector.get();
    searchParameters = &ivfParams;
}
Expand All @@ -374,17 +376,25 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
} else {
faiss::SearchParameters *searchParameters = nullptr;
faiss::SearchParametersHNSW hnswParams;
faiss::SearchParametersIVF ivfParams;
std::unique_ptr<faiss::IDGrouperBitmap> idGrouper;
std::vector<uint64_t> idGrouperBitmap;
auto hnswReader = dynamic_cast<const faiss::IndexHNSW*>(indexReader->index);
if(hnswReader!= nullptr) {
if(hnswReader != nullptr) {
// Query param efsearch supersedes ef_search provided during index setting.
hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch);
if (parentIdsJ != nullptr) {
idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap);
hnswParams.grp = idGrouper.get();
}
searchParameters = &hnswParams;
} else {
auto ivfReader = dynamic_cast<const faiss::IndexIVF*>(indexReader->index);
if (ivfReader) {
int indexNprobe = ivfReader->nprobe;
ivfParams.nprobe = commons::getIntegerMethodParameter(env, jniUtil, methodParams, NPROBES, indexNprobe);
searchParameters = &ivfParams;
}
}
try {
indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data(), searchParameters);
Expand Down
174 changes: 144 additions & 30 deletions jni/tests/faiss_wrapper_unit_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include "faiss_wrapper.h"

#include <vector>
#include <faiss/IndexIVFFlat.h>
#include <faiss/IndexIVFPQ.h>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
Expand All @@ -30,6 +32,46 @@ struct MockIndex : faiss::IndexHNSW {
}
};

// Test stand-in for an IVF index: a default-constructed faiss::IndexIVFFlat that adds no behavior of its own.
struct MockIVFIndex : faiss::IndexIVFFlat {
explicit MockIVFIndex() = default;
};

// Test double for an ID map wrapping a MockIVFIndex. Its search() performs no search at all;
// it only records the arguments it was called with so a test can inspect them afterwards.
struct MockIVFIdMap : faiss::IndexIDMap {
    // Arguments captured by the most recent search() call. They are mutable because
    // search() and resetMock() are const member functions.
    mutable idx_t nCalled = 0;
    mutable const float *xCalled = nullptr;
    mutable idx_t kCalled = 0;
    mutable float *distancesCalled = nullptr;
    mutable idx_t *labelsCalled = nullptr;
    // Null when search() received parameters that are not SearchParametersIVF (or none at all).
    mutable const faiss::SearchParametersIVF *paramsCalled = nullptr;

    explicit MockIVFIdMap(MockIVFIndex *index) : faiss::IndexIDMap(index) {}

    // Records every argument; the output buffers are left untouched.
    void search(
            idx_t n,
            const float *x,
            idx_t k,
            float *distances,
            idx_t *labels,
            const faiss::SearchParameters *params) const override {
        paramsCalled = dynamic_cast<const faiss::SearchParametersIVF *>(params);
        labelsCalled = labels;
        distancesCalled = distances;
        kCalled = k;
        xCalled = x;
        nCalled = n;
    }

    // Clears everything captured so far.
    void resetMock() const {
        paramsCalled = nullptr;
        labelsCalled = nullptr;
        distancesCalled = nullptr;
        kCalled = 0;
        xCalled = nullptr;
        nCalled = 0;
    }
};

struct MockIdMap : faiss::IndexIDMap {
mutable idx_t nCalled{};
mutable const float *xCalled{};
Expand Down Expand Up @@ -83,13 +125,14 @@ struct MockIdMap : faiss::IndexIDMap {
}
};

struct QueryIndexHNSWTestInput {
std::string description;
struct QueryIndexInput {
string description;
int k;
int efSearch;
int filterIdType;
bool filterIdsPresent;
bool parentIdsPresent;
int efSearch;
int nprobe;
};

struct RangeSearchTestInput {
Expand All @@ -101,9 +144,9 @@ struct RangeSearchTestInput {
bool parentIdsPresent;
};

class FaissWrappeterParametrizedTestFixture : public testing::TestWithParam<QueryIndexHNSWTestInput> {
class FaissWrapperParametrizedTestFixture : public testing::TestWithParam<QueryIndexInput> {
public:
FaissWrappeterParametrizedTestFixture() : index_(3), id_map_(&index_) {
FaissWrapperParametrizedTestFixture() : index_(3), id_map_(&index_) {
index_.hnsw.efSearch = 100; // assigning 100 to make sure default of 16 is not used anywhere
}

Expand All @@ -123,16 +166,25 @@ class FaissWrapperParametrizedRangeSearchTestFixture : public testing::TestWithP
MockIdMap id_map_;
};

namespace query_index_test {
// Value-parameterized fixture for IVF query tests: owns a mock IVF index and a mock ID map wrapping it.
class FaissWrapperIVFQueryTestFixture : public testing::TestWithParam<QueryIndexInput> {
public:
// ivf_id_map_ is handed the address of ivf_index_; ivf_index_ is declared first below,
// so it is constructed before the map that points at it.
FaissWrapperIVFQueryTestFixture() : ivf_id_map_(&ivf_index_) {
// Index-level nprobe for the mock index.
ivf_index_.nprobe = 100;
};

std::unordered_map<std::string, jobject> methodParams;
protected:
// Declaration order matters: the index must precede the ID map (see constructor).
MockIVFIndex ivf_index_;
MockIVFIdMap ivf_id_map_;
};

TEST_P(FaissWrappeterParametrizedTestFixture, QueryIndexHNSWTests) {
// Given
namespace query_index_test {

TEST_P(FaissWrapperParameterizedTestFixture, QueryIndexHNSWTests) {
//Given
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;

QueryIndexHNSWTestInput const &input = GetParam();
QueryIndexInput const &input = GetParam();
float query[] = {1.2, 2.3, 3.4};

int efSearch = input.efSearch;
Expand Down Expand Up @@ -184,24 +236,23 @@ namespace query_index_test {

INSTANTIATE_TEST_CASE_P(
QueryIndexHNSWTests,
FaissWrappeterParametrizedTestFixture,
FaissWrapperParametrizedTestFixture,
::testing::Values(
QueryIndexHNSWTestInput{"algoParams present, parent absent", 10, 200, 0, false, false},
QueryIndexHNSWTestInput{"algoParams absent, parent absent", 10, -1, 0, false, false},
QueryIndexHNSWTestInput{"algoParams present, parent present", 10, 200, 0, false, true},
QueryIndexHNSWTestInput{"algoParams absent, parent present", 10, -1, 0, false, true}
QueryIndexInput {"algoParams present, parent absent", 10, 0, false, false, 200, -1 },
QueryIndexInput {"algoParams present, parent absent", 10, 0, false, false, -1, -1 },
QueryIndexInput {"algoParams present, parent present", 10, 0, false, true, 200, -1 },
QueryIndexInput {"algoParams absent, parent present", 10, 0, false, true, -1, -1 }
)
);
}

namespace query_index_with_filter_test {

TEST_P(FaissWrappeterParametrizedTestFixture, QueryIndexWithFilterHNSWTests) {
// Given
TEST_P(FaissWrapperParameterizedTestFixture, QueryIndexWithFilterHNSWTests) {
//Given
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;

QueryIndexHNSWTestInput const &input = GetParam();
QueryIndexInput const &input = GetParam();
float query[] = {1.2, 2.3, 3.4};

std::vector<int> *parentIdPtr = nullptr;
Expand Down Expand Up @@ -267,23 +318,23 @@ namespace query_index_with_filter_test {

INSTANTIATE_TEST_CASE_P(
QueryIndexWithFilterHNSWTests,
FaissWrappeterParametrizedTestFixture,
FaissWrapperParametrizedTestFixture,
::testing::Values(
QueryIndexHNSWTestInput{"algoParams present, parent absent, filter absent", 10, 200, 0, false, false},
QueryIndexHNSWTestInput{"algoParams present, parent absent, filter absent, filter type 1", 10, 200, 1, false, false},
QueryIndexHNSWTestInput{"algoParams absent, parent absent, filter present", 10, -1, 0, true, false},
QueryIndexHNSWTestInput{"algoParams absent, parent absent, filter present, filter type 1", 10, -1, 1, true, false},
QueryIndexHNSWTestInput{"algoParams present, parent present, filter absent", 10, 200, 0, false, true},
QueryIndexHNSWTestInput{"algoParams present, parent present, filter absent, filter type 1", 10, 150, 1, false, true},
QueryIndexHNSWTestInput{"algoParams absent, parent present, filter present", 10, -1, 0, true, true},
QueryIndexHNSWTestInput{"algoParams absent, parent present, filter present, filter type 1",10, -1, 1, true, true}
QueryIndexInput { "algoParams present, parent absent, filter absent", 10, 0, false, false, 200, -1 },
QueryIndexInput { "algoParams present, parent absent, filter absent, filter type 1", 10, 1, false, false, 200, -1},
QueryIndexInput { "algoParams absent, parent absent, filter present", 10, 0, true, false, -1, -1},
QueryIndexInput { "algoParams absent, parent absent, filter present, filter type 1", 10, 1, true, false, -1, -1},
QueryIndexInput { "algoParams present, parent present, filter absent", 10, 0, false, true, 200, -1 },
QueryIndexInput { "algoParams present, parent present, filter absent, filter type 1", 10, 1, false, true, 150, -1},
QueryIndexInput { "algoParams absent, parent present, filter present", 10, 0, true, true, -1, -1},
QueryIndexInput { "algoParams absent, parent present, filter present, filter type 1",10, 1, true, true, -1, -1 }
)
);
}

namespace range_search_test {

TEST_P(FaissWrapperParametrizedRangeSearchTestFixture, RangeSearchHNSWTests) {
TEST_P(FaissWrapperParameterizedRangeSearchTestFixture, RangeSearchHNSWTests) {
// Given
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;
Expand Down Expand Up @@ -323,6 +374,7 @@ namespace range_search_test {
std::vector<long> filter;
std::vector<long> *filterptr = nullptr;
if (input.filterIdsPresent) {
std::vector<long> filter;
filter.reserve(2);
filter.push_back(1);
filter.push_back(2);
Expand Down Expand Up @@ -356,7 +408,7 @@ namespace range_search_test {

INSTANTIATE_TEST_CASE_P(
RangeSearchHNSWTests,
FaissWrapperParametrizedRangeSearchTestFixture,
FaissWrapperParameterizedRangeSearchTestFixture,
::testing::Values(
RangeSearchTestInput{"algoParams present, parent absent, filter absent", 10.0f, 200, 0, false, false},
RangeSearchTestInput{"algoParams present, parent absent, filter absent, filter type 1", 10.0f, 200, 1, false, false},
Expand All @@ -370,3 +422,65 @@ namespace range_search_test {
);
}

namespace query_index_with_filter_test_ivf {

    // Checks which nprobe QueryIndex_WithFilter hands to an IVF index: the value from the
    // query's method parameters when one is supplied, otherwise the nprobe set on the index.
    TEST_P(FaissWrapperIVFQueryTestFixture, QueryIndexIVFTest) {
        //Given
        JNIEnv *jniEnv = nullptr;
        NiceMock<test_util::MockJNIUtil> mockJNIUtil;

        QueryIndexInput const &input = GetParam();
        float query[] = {1.2, 2.3, 3.4};

        int nprobe = input.nprobe;
        int expectedNprobe = 100; // index-level nprobe assigned by the fixture
        std::unordered_map<std::string, jobject> methodParams;
        if (nprobe != -1) {
            // -1 marks "no nprobe in the query"; any other value is passed as a method parameter.
            expectedNprobe = input.nprobe;
            methodParams[knn_jni::NPROBES] = reinterpret_cast<jobject>(&nprobe);
        }

        // `filter` lives at function scope so that `filterptr` is still valid when the wrapper
        // is called, and a single `filterptr` is used so the filter actually reaches the call.
        std::vector<long> filter;
        std::vector<long> *filterptr = nullptr;
        if (input.filterIdsPresent) {
            filter.reserve(2);
            filter.push_back(1);
            filter.push_back(2);
            filterptr = &filter;
        }

        // When
        knn_jni::faiss_wrapper::QueryIndex_WithFilter(
            &mockJNIUtil, jniEnv,
            reinterpret_cast<jlong>(&ivf_id_map_),
            reinterpret_cast<jfloatArray>(&query), input.k, reinterpret_cast<jobject>(&methodParams),
            reinterpret_cast<jlongArray>(filterptr),
            input.filterIdType,
            nullptr);

        //Then
        // The mock stores nullptr when it was not given IVF search parameters; stop here
        // instead of dereferencing a null pointer below.
        ASSERT_TRUE(ivf_id_map_.paramsCalled != nullptr);
        int actualNprobe = ivf_id_map_.paramsCalled->nprobe;
        // Asserting the captured argument
        EXPECT_EQ(input.k, ivf_id_map_.kCalled);
        EXPECT_EQ(expectedNprobe, actualNprobe);
        if (input.parentIdsPresent) {
            // NOTE(review): the call above passes nullptr for the parent ids — confirm that the
            // IVF path is expected to populate grp for rows with parentIdsPresent == true.
            faiss::IDGrouper *grouper = ivf_id_map_.paramsCalled->grp;
            EXPECT_TRUE(grouper != nullptr);
        }
        ivf_id_map_.resetMock();
    }

    // Field order: description, k, filterIdType, filterIdsPresent, parentIdsPresent, efSearch, nprobe.
    INSTANTIATE_TEST_CASE_P(
        QueryIndexIVFTest,
        FaissWrapperIVFQueryTestFixture,
        ::testing::Values(
            QueryIndexInput{"algoParams present, parent absent", 10, 0, false, false, -1, 200 },
            QueryIndexInput{"algoParams absent, parent absent", 10, 0, false, false, -1, -1 },
            QueryIndexInput{"algoParams present, parent present", 10, 0, true, true, -1, 200 },
            QueryIndexInput{"algoParams absent, parent present", 10, 0, true, true, -1, -1 }
        )
    );
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.knn.index.util.EngineSpecificMethodContext;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.VectorQueryType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.query.parser.MethodParametersParser;
import org.opensearch.knn.index.util.EngineSpecificMethodContext;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.index.util.QueryContext;
import org.opensearch.knn.indices.ModelDao;
Expand All @@ -50,6 +50,7 @@
import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES;
import static org.opensearch.knn.common.KNNConstants.MIN_SCORE;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;
import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion;
Expand All @@ -73,6 +74,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
public static final ParseField MAX_DISTANCE_FIELD = new ParseField(MAX_DISTANCE);
public static final ParseField MIN_SCORE_FIELD = new ParseField(MIN_SCORE);
public static final ParseField EF_SEARCH_FIELD = new ParseField(METHOD_PARAMETER_EF_SEARCH);
public static final ParseField NPROBE_FIELD = new ParseField(METHOD_PARAMETER_NPROBES);
public static final ParseField METHOD_PARAMS_FIELD = new ParseField(METHOD_PARAMETER);
public static final int K_MAX = 10000;
/**
Expand Down Expand Up @@ -217,7 +219,8 @@ private void validate() {

if (k != null) {
if (k <= 0 || k > K_MAX) {
throw new IllegalArgumentException(String.format("[%s] requires k to be in the range (0, %d]", NAME, K_MAX));
final String errorMessage = "[" + NAME + "] requires k to be in the range (0, " + K_MAX + "]";
throw new IllegalArgumentException(errorMessage);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
import static org.opensearch.knn.index.query.KNNQueryBuilder.METHOD_PARAMS_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.NAME;

/**
* Note: This parser is also used by the neural plugin, so breaking changes here will require matching changes in neural.
*/
@EqualsAndHashCode
@Getter
@AllArgsConstructor
Expand Down
Loading

0 comments on commit c034ec1

Please sign in to comment.