Skip to content

Commit

Permalink
Support radial search in k-NN
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Apr 18, 2024
1 parent 4abe91f commit 4398d52
Show file tree
Hide file tree
Showing 32 changed files with 2,163 additions and 151 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.13...2.x)
### Features
* Add Clear Cache API [#740](https://github.com/opensearch-project/k-NN/pull/740)
* Support radial search in k-NN plugin [#1617](https://github.com/opensearch-project/k-NN/pull/1617)
### Enhancements
* Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549)
* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573)
Expand Down
12 changes: 12 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@ namespace knn_jni {
// Return the serialized representation
jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
jlong trainVectorsPointerJ);

/*
* Perform a range search against the index located in memory at indexPointerJ.
*
* @param indexPointerJ - pointer to the index
* @param queryVectorJ - the query vector
* @param radiusJ - the radius for the range search
* @param maxResultsWindowJ - the maximum number of results to return
* @return an array of RangeQueryResults
*/
jobjectArray RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, jfloatArray queryVectorJ,
jfloat radiusJ, jint maxResultsWindowJ);
}
}

Expand Down
8 changes: 8 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,14 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors
(JNIEnv *, jclass, jlong, jobjectArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: rangeSearchIndex
* Signature: (J[F[F)J
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex
(JNIEnv *, jclass, jlong, jfloatArray, jfloat, jint);

#ifdef __cplusplus
}
#endif
Expand Down
44 changes: 44 additions & 0 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -587,3 +587,47 @@ faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index) {

throw std::runtime_error("Unable to extract IVFPQ index. IVFPQ index not present.");
}

jobjectArray knn_jni::faiss_wrapper::RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ,
jfloatArray queryVectorJ, jfloat radiusJ, jint maxResultWindowJ) {
if (queryVectorJ == nullptr) {
throw std::runtime_error("Query Vector cannot be null");
}

auto *indexReader = reinterpret_cast<faiss::IndexIDMap *>(indexPointerJ);

if (indexReader == nullptr) {
throw std::runtime_error("Invalid pointer to indexReader");
}

float *rawQueryVector = jniUtil->GetFloatArrayElements(env, queryVectorJ, nullptr);

// The res will be freed by ~RangeSearchResult() in FAISS
// The second parameter is always true, as lims is allocated by FAISS
faiss::RangeSearchResult res(1, true);
indexReader->range_search(1, rawQueryVector, radiusJ, &res);

// lims is structured to support batched queries, it has a length of nq + 1 (where nq is the number of queries),
// lims[i] - lims[i-1] gives the number of results for the i-th query. With a single query we used in k-NN,
// res.lims[0] is always 0, and res.lims[1] gives the total number of matching entries found.
int resultSize = res.lims[1];

// Limit the result size to maxResultWindowJ so that we don't return more than the max result window
// TODO: In the future, we should prevent this via FAISS's ResultHandler.
if (resultSize > maxResultWindowJ) {
resultSize = maxResultWindowJ;
}

jclass resultClass = jniUtil->FindClass(env,"org/opensearch/knn/index/query/KNNQueryResult");
jmethodID allArgs = jniUtil->FindMethod(env, "org/opensearch/knn/index/query/KNNQueryResult", "<init>");

jobjectArray results = jniUtil->NewObjectArray(env, resultSize, resultClass, nullptr);

jobject result;
for(int i = 0; i < resultSize; ++i) {
result = jniUtil->NewObject(env, resultClass, allArgs, res.labels[i], res.distances[i]);
jniUtil->SetObjectArrayElement(env, results, i, result);
}

return results;
}
14 changes: 14 additions & 0 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,17 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors

return (jlong) vect;
}

JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex(JNIEnv * env, jclass cls,
jlong indexPointerJ,
jfloatArray queryVectorJ,
jfloat radiusJ, jint maxResultWindowJ)
{
try {
return knn_jni::faiss_wrapper::RangeSearch(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, maxResultWindowJ);

} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
return nullptr;
}
113 changes: 113 additions & 0 deletions jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -615,3 +615,116 @@ TEST(FaissInitAndSetSharedIndexState, BasicAssertions) {
ASSERT_EQ(1, ivfpqIndex->use_precomputed_table);
knn_jni::faiss_wrapper::FreeSharedIndexState(sharedModelAddress);
}

TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) {
// Define the index data
faiss::idx_t numIds = 200;
int dim = 2;
std::vector<faiss::idx_t> ids = test_util::Range(numIds);
std::vector<float> vectors = test_util::RandomVectors(dim, numIds, randomDataMin, randomDataMax);

faiss::MetricType metricType = faiss::METRIC_L2;
std::string method = "HNSW32,Flat";

// Define query data
float radius = 100000.0;
int numQueries = 2;
std::vector<std::vector<float>> queries;

for (int i = 0; i < numQueries; i++) {
std::vector<float> query;
query.reserve(dim);
for (int j = 0; j < dim; j++) {
query.push_back(test_util::RandomFloat(randomDataMin, randomDataMax));
}
queries.push_back(query);
}

// Create the index
std::unique_ptr<faiss::Index> createdIndex(
test_util::FaissCreateIndex(dim, method, metricType));
auto createdIndexWithData =
test_util::FaissAddData(createdIndex.get(), ids, vectors);

// Setup jni
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;

int maxResultWindow = 20000;

for (auto query : queries) {
std::unique_ptr<std::vector<std::pair<int, float> *>> results(
reinterpret_cast<std::vector<std::pair<int, float> *> *>(

knn_jni::faiss_wrapper::RangeSearch(
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&createdIndexWithData),
reinterpret_cast<jfloatArray>(&query), radius, maxResultWindow)));

// assert result size is not 0
ASSERT_NE(0, results->size());


// Need to free up each result
for (auto it : *results) {
delete it;
}
}
}

TEST(FaissRangeSearchQueryIndexTest_WhenHitMaxWindowResult, BasicAssertions){
// Define the index data
faiss::idx_t numIds = 200;
int dim = 2;
std::vector<faiss::idx_t> ids = test_util::Range(numIds);
std::vector<float> vectors = test_util::RandomVectors(dim, numIds, randomDataMin, randomDataMax);

faiss::MetricType metricType = faiss::METRIC_L2;
std::string method = "HNSW32,Flat";

// Define query data
float radius = 100000.0;
int numQueries = 2;
std::vector<std::vector<float>> queries;

for (int i = 0; i < numQueries; i++) {
std::vector<float> query;
query.reserve(dim);
for (int j = 0; j < dim; j++) {
query.push_back(test_util::RandomFloat(randomDataMin, randomDataMax));
}
queries.push_back(query);
}

// Create the index
std::unique_ptr<faiss::Index> createdIndex(
test_util::FaissCreateIndex(dim, method, metricType));
auto createdIndexWithData =
test_util::FaissAddData(createdIndex.get(), ids, vectors);

// Setup jni
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;

int maxResultWindow = 10;

for (auto query : queries) {
std::unique_ptr<std::vector<std::pair<int, float> *>> results(
reinterpret_cast<std::vector<std::pair<int, float> *> *>(

knn_jni::faiss_wrapper::RangeSearch(
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&createdIndexWithData),
reinterpret_cast<jfloatArray>(&query), radius, maxResultWindow)));

// assert result size is not 0
ASSERT_NE(0, results->size());
// assert result size is equal to maxResultWindow
ASSERT_EQ(maxResultWindow, results->size());

// Need to free up each result
for (auto it : *results) {
delete it;
}
}
}
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ public class KNNConstants {
public static final String VECTOR_DATA_TYPE_FIELD = "data_type";
public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT;

public static final String RADIAL_SEARCH_KEY = "radial_search";

// Lucene specific constants
public static final String LUCENE_NAME = "lucene";

Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,14 @@ public class IndexUtil {
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT = Version.V_2_12_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT = Version.V_2_13_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH = Version.V_2_14_0;
public static final Map<String, Version> minimalRequiredVersionMap = new HashMap<String, Version>() {
{
put("filter", MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER);
put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED);
put(MODEL_NODE_ASSIGNMENT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT);
put(MODEL_METHOD_COMPONENT_CONTEXT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT);
put(KNNConstants.RADIAL_SEARCH_KEY, MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH);
}
};

Expand Down
18 changes: 18 additions & 0 deletions src/main/java/org/opensearch/knn/index/SpaceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ public float scoreTranslation(float rawScore) {
public VectorSimilarityFunction getVectorSimilarityFunction() {
return VectorSimilarityFunction.EUCLIDEAN;
}

@Override
public float scoreToDistanceTranslation(float score) {
if (score == 0) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "score cannot be 0 when space type is [%s]", getValue()));
}
return 1 / score - 1;
}
},
COSINESIMIL("cosinesimil") {
@Override
Expand Down Expand Up @@ -170,4 +178,14 @@ public static SpaceType getSpace(String spaceTypeName) {
}
throw new IllegalArgumentException("Unable to find space: " + spaceTypeName);
}

/**
* Translate a score to a distance for this space type
*
* @param score score to translate
* @return translated distance
*/
public float scoreToDistanceTranslation(float score) {
throw new UnsupportedOperationException(String.format("Space [%s] does not have a score to distance translation", getValue()));
}
}
95 changes: 95 additions & 0 deletions src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.ToChildBlockJoinQuery;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.search.NestedHelper;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.util.KNNEngine;

import java.io.IOException;
import java.util.Optional;

/**
* Base class for creating vector search queries.
*/
@Log4j2
public abstract class BaseQueryFactory {
/**
* DTO object to hold data required to create a Query instance.
*/
@AllArgsConstructor
@Builder
@Getter
public static class CreateQueryRequest {
@NonNull
private KNNEngine knnEngine;
@NonNull
private String indexName;
private String fieldName;
private float[] vector;
private byte[] byteVector;
private VectorDataType vectorDataType;
private Integer k;
private Float radius;
private QueryBuilder filter;
private QueryShardContext context;

public Optional<QueryBuilder> getFilter() {
return Optional.ofNullable(filter);
}

public Optional<QueryShardContext> getContext() {
return Optional.ofNullable(context);
}
}

/**
* Creates a query filter.
*
* @param createQueryRequest request object that has all required fields to construct the query
* @return Lucene Query
*/
protected static Query getFilterQuery(BaseQueryFactory.CreateQueryRequest createQueryRequest) {
if (!createQueryRequest.getFilter().isPresent()) {
return null;
}

final QueryShardContext queryShardContext = createQueryRequest.getContext()
.orElseThrow(() -> new RuntimeException("Shard context cannot be null"));
log.debug(
String.format(
"Creating query with filter for index [%s], field [%s]",
createQueryRequest.getIndexName(),
createQueryRequest.getFieldName()
)
);
final Query filterQuery;
try {
filterQuery = createQueryRequest.getFilter().get().toQuery(queryShardContext);
} catch (IOException e) {
throw new RuntimeException("Cannot create query with filter", e);
}
BitSetProducer parentFilter = queryShardContext.getParentFilter();
if (parentFilter != null) {
boolean mightMatch = new NestedHelper(queryShardContext.getMapperService()).mightMatchNestedDocs(filterQuery);
if (mightMatch) {
return filterQuery;
}
return new ToChildBlockJoinQuery(filterQuery, parentFilter);
}
return filterQuery;
}
}
Loading

0 comments on commit 4398d52

Please sign in to comment.