Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add radial search feature to main branch #1617

Merged
merged 6 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.13...2.x)
### Features
* Support radial search in k-NN plugin [#1617](https://github.com/opensearch-project/k-NN/pull/1617)
### Enhancements
* Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549)
* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573)
Expand Down
12 changes: 12 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@ namespace knn_jni {
// Return the serialized representation
jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
jlong trainVectorsPointerJ);

/*
* Perform a range search against the index located in memory at indexPointerJ.
*
* @param indexPointerJ - pointer to the index
* @param queryVectorJ - the query vector
* @param radiusJ - the radius for the range search
* @param maxResultsWindowJ - the maximum number of results to return
* @return an array of RangeQueryResults
*/
jobjectArray RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, jfloatArray queryVectorJ,
jfloat radiusJ, jint maxResultsWindowJ);
}
}

Expand Down
8 changes: 8 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,14 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors
(JNIEnv *, jclass, jlong, jobjectArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: rangeSearchIndex
* Signature: (J[F[F)J
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex
(JNIEnv *, jclass, jlong, jfloatArray, jfloat, jint);

#ifdef __cplusplus
}
#endif
Expand Down
44 changes: 44 additions & 0 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -587,3 +587,47 @@ faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index) {

throw std::runtime_error("Unable to extract IVFPQ index. IVFPQ index not present.");
}

jobjectArray knn_jni::faiss_wrapper::RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ,
jfloatArray queryVectorJ, jfloat radiusJ, jint maxResultWindowJ) {
if (queryVectorJ == nullptr) {
throw std::runtime_error("Query Vector cannot be null");
}

auto *indexReader = reinterpret_cast<faiss::IndexIDMap *>(indexPointerJ);

if (indexReader == nullptr) {
throw std::runtime_error("Invalid pointer to indexReader");
}

float *rawQueryVector = jniUtil->GetFloatArrayElements(env, queryVectorJ, nullptr);

// The res will be freed by ~RangeSearchResult() in FAISS
// The second parameter is always true, as lims is allocated by FAISS
faiss::RangeSearchResult res(1, true);
indexReader->range_search(1, rawQueryVector, radiusJ, &res);

// lims is structured to support batched queries, it has a length of nq + 1 (where nq is the number of queries),
// lims[i] - lims[i-1] gives the number of results for the i-th query. With a single query we used in k-NN,
// res.lims[0] is always 0, and res.lims[1] gives the total number of matching entries found.
int resultSize = res.lims[1];

// Limit the result size to maxResultWindowJ so that we don't return more than the max result window
// TODO: In the future, we should prevent this via FAISS's ResultHandler.
if (resultSize > maxResultWindowJ) {
resultSize = maxResultWindowJ;
}

jclass resultClass = jniUtil->FindClass(env,"org/opensearch/knn/index/query/KNNQueryResult");
jmethodID allArgs = jniUtil->FindMethod(env, "org/opensearch/knn/index/query/KNNQueryResult", "<init>");

jobjectArray results = jniUtil->NewObjectArray(env, resultSize, resultClass, nullptr);

jobject result;
for(int i = 0; i < resultSize; ++i) {
result = jniUtil->NewObject(env, resultClass, allArgs, res.labels[i], res.distances[i]);
jniUtil->SetObjectArrayElement(env, results, i, result);
}

return results;
}
14 changes: 14 additions & 0 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,17 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors

return (jlong) vect;
}

JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex(JNIEnv * env, jclass cls,
jlong indexPointerJ,
jfloatArray queryVectorJ,
jfloat radiusJ, jint maxResultWindowJ)
{
try {
return knn_jni::faiss_wrapper::RangeSearch(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, maxResultWindowJ);

} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
return nullptr;
}
113 changes: 113 additions & 0 deletions jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -615,3 +615,116 @@ TEST(FaissInitAndSetSharedIndexState, BasicAssertions) {
ASSERT_EQ(1, ivfpqIndex->use_precomputed_table);
knn_jni::faiss_wrapper::FreeSharedIndexState(sharedModelAddress);
}

TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) {
// Define the index data
faiss::idx_t numIds = 200;
int dim = 2;
std::vector<faiss::idx_t> ids = test_util::Range(numIds);
std::vector<float> vectors = test_util::RandomVectors(dim, numIds, randomDataMin, randomDataMax);

faiss::MetricType metricType = faiss::METRIC_L2;
std::string method = "HNSW32,Flat";

// Define query data
float radius = 100000.0;
int numQueries = 2;
std::vector<std::vector<float>> queries;

for (int i = 0; i < numQueries; i++) {
std::vector<float> query;
query.reserve(dim);
for (int j = 0; j < dim; j++) {
query.push_back(test_util::RandomFloat(randomDataMin, randomDataMax));
}
queries.push_back(query);
}

// Create the index
std::unique_ptr<faiss::Index> createdIndex(
test_util::FaissCreateIndex(dim, method, metricType));
auto createdIndexWithData =
test_util::FaissAddData(createdIndex.get(), ids, vectors);

// Setup jni
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;

int maxResultWindow = 20000;

for (auto query : queries) {
std::unique_ptr<std::vector<std::pair<int, float> *>> results(
reinterpret_cast<std::vector<std::pair<int, float> *> *>(

knn_jni::faiss_wrapper::RangeSearch(
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&createdIndexWithData),
reinterpret_cast<jfloatArray>(&query), radius, maxResultWindow)));

// assert result size is not 0
ASSERT_NE(0, results->size());


// Need to free up each result
for (auto it : *results) {
delete it;
}
}
}

TEST(FaissRangeSearchQueryIndexTest_WhenHitMaxWindowResult, BasicAssertions){
// Define the index data
faiss::idx_t numIds = 200;
int dim = 2;
std::vector<faiss::idx_t> ids = test_util::Range(numIds);
std::vector<float> vectors = test_util::RandomVectors(dim, numIds, randomDataMin, randomDataMax);

faiss::MetricType metricType = faiss::METRIC_L2;
std::string method = "HNSW32,Flat";

// Define query data
float radius = 100000.0;
int numQueries = 2;
std::vector<std::vector<float>> queries;

for (int i = 0; i < numQueries; i++) {
std::vector<float> query;
query.reserve(dim);
for (int j = 0; j < dim; j++) {
query.push_back(test_util::RandomFloat(randomDataMin, randomDataMax));
}
queries.push_back(query);
}

// Create the index
std::unique_ptr<faiss::Index> createdIndex(
test_util::FaissCreateIndex(dim, method, metricType));
auto createdIndexWithData =
test_util::FaissAddData(createdIndex.get(), ids, vectors);

// Setup jni
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;

int maxResultWindow = 10;

for (auto query : queries) {
std::unique_ptr<std::vector<std::pair<int, float> *>> results(
reinterpret_cast<std::vector<std::pair<int, float> *> *>(

knn_jni::faiss_wrapper::RangeSearch(
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&createdIndexWithData),
reinterpret_cast<jfloatArray>(&query), radius, maxResultWindow)));

// assert result size is not 0
ASSERT_NE(0, results->size());
// assert result size is equal to maxResultWindow
ASSERT_EQ(maxResultWindow, results->size());

// Need to free up each result
for (auto it : *results) {
delete it;
}
}
}
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ public class KNNConstants {
public static final String VECTOR_DATA_TYPE_FIELD = "data_type";
public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT;

public static final String RADIAL_SEARCH_KEY = "radial_search";

// Lucene specific constants
public static final String LUCENE_NAME = "lucene";

Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,13 @@ public class IndexUtil {
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT = Version.V_2_12_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT = Version.V_2_13_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH = Version.V_2_14_0;
private static final Map<String, Version> minimalRequiredVersionMap = new HashMap<String, Version>() {
{
put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED);
put(MODEL_NODE_ASSIGNMENT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT);
put(MODEL_METHOD_COMPONENT_CONTEXT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT);
put(KNNConstants.RADIAL_SEARCH_KEY, MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH);
}
};

Expand Down
18 changes: 18 additions & 0 deletions src/main/java/org/opensearch/knn/index/SpaceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ public float scoreTranslation(float rawScore) {
public VectorSimilarityFunction getVectorSimilarityFunction() {
return VectorSimilarityFunction.EUCLIDEAN;
}

@Override
public float scoreToDistanceTranslation(float score) {
if (score == 0) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "score cannot be 0 when space type is [%s]", getValue()));
}
return 1 / score - 1;
}
},
COSINESIMIL("cosinesimil") {
@Override
Expand Down Expand Up @@ -170,4 +178,14 @@ public static SpaceType getSpace(String spaceTypeName) {
}
throw new IllegalArgumentException("Unable to find space: " + spaceTypeName);
}

/**
* Translate a score to a distance for this space type
*
* @param score score to translate
* @return translated distance
*/
public float scoreToDistanceTranslation(float score) {
throw new UnsupportedOperationException(String.format("Space [%s] does not have a score to distance translation", getValue()));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.ToChildBlockJoinQuery;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.search.NestedHelper;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.util.KNNEngine;

import java.io.IOException;
import java.util.Optional;

/**
* Base class for creating vector search queries.
*/
@Log4j2
public abstract class BaseQueryFactory {
/**
* DTO object to hold data required to create a Query instance.
*/
@AllArgsConstructor
@Builder
@Getter
public static class CreateQueryRequest {
@NonNull
private KNNEngine knnEngine;
@NonNull
private String indexName;
private String fieldName;
private float[] vector;
private byte[] byteVector;
private VectorDataType vectorDataType;
private Integer k;
private Float radius;
private QueryBuilder filter;
private QueryShardContext context;

public Optional<QueryBuilder> getFilter() {
return Optional.ofNullable(filter);
}

public Optional<QueryShardContext> getContext() {
return Optional.ofNullable(context);
}
}

/**
* Creates a query filter.
*
* @param createQueryRequest request object that has all required fields to construct the query
* @return Lucene Query
*/
protected static Query getFilterQuery(BaseQueryFactory.CreateQueryRequest createQueryRequest) {
if (!createQueryRequest.getFilter().isPresent()) {
return null;
}

final QueryShardContext queryShardContext = createQueryRequest.getContext()
.orElseThrow(() -> new RuntimeException("Shard context cannot be null"));
log.debug(
String.format(
"Creating query with filter for index [%s], field [%s]",
createQueryRequest.getIndexName(),
createQueryRequest.getFieldName()
)
);
final Query filterQuery;
try {
filterQuery = createQueryRequest.getFilter().get().toQuery(queryShardContext);
} catch (IOException e) {
throw new RuntimeException("Cannot create query with filter", e);
}
BitSetProducer parentFilter = queryShardContext.getParentFilter();
if (parentFilter != null) {
boolean mightMatch = new NestedHelper(queryShardContext.getMapperService()).mightMatchNestedDocs(filterQuery);
if (mightMatch) {
return filterQuery;
}
return new ToChildBlockJoinQuery(filterQuery, parentFilter);
}
return filterQuery;
}
}
Loading
Loading