Skip to content

Commit

Permalink
Add support for radial search in exact search (opensearch-project#2174)
Browse files Browse the repository at this point in the history
* Add support for radial search in exact search

When threshold value is set, knn plugin will not be creating graph.
Hence, when search request is trigged during that time, exact search
will return valid results. However, radial search was never included
as part of exact search. This will break radial search when threshold
is added and radial search is requested. In this commit, new method
is introduced to accept min score and return documents that are greater
than min score, similar to how radial search is performed by native
engines. This search is independent of engine, but, radial search is
supported only for FAISS engine out of all native engines.

Signed-off-by: Vijayan Balasubramanian <[email protected]>
---------

Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Oct 8, 2024
1 parent 1e90b84 commit 87b1d0b
Show file tree
Hide file tree
Showing 8 changed files with 486 additions and 20 deletions.
55 changes: 49 additions & 6 deletions src/main/java/org/opensearch/knn/index/query/ExactSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

package org.opensearch.knn.index.query;

import com.google.common.base.Predicates;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.NonNull;
import lombok.Value;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.FieldInfo;
Expand All @@ -21,6 +23,7 @@
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.query.iterators.BinaryVectorIdsKNNIterator;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.iterators.ByteVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.NestedBinaryVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.VectorIdsKNNIterator;
Expand All @@ -36,7 +39,9 @@

import java.io.IOException;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.function.Predicate;

@Log4j2
@AllArgsConstructor
Expand All @@ -55,11 +60,41 @@ public class ExactSearcher {
public Map<Integer, Float> searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext)
throws IOException {
KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext);
if (exactSearcherContext.getKnnQuery().getRadius() != null) {
return doRadialSearch(leafReaderContext, exactSearcherContext, iterator);
}
if (exactSearcherContext.getMatchedDocs() != null
&& exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) {
return scoreAllDocs(iterator);
}
return searchTopK(iterator, exactSearcherContext.getK());
return searchTopCandidates(iterator, exactSearcherContext.getK(), Predicates.alwaysTrue());
}

/**
* Perform radial search by comparing scores with min score. Currently, FAISS from native engine supports radial search.
* Hence, we assume that Radius from knnQuery is always distance, and we convert it to score since we do exact search uses scores
* to filter out the documents that does not have given min score.
* @param leafReaderContext
* @param exactSearcherContext
* @param iterator {@link KNNIterator}
* @return Map of docId and score
* @throws IOException exception raised by iterator during traversal
*/
private Map<Integer, Float> doRadialSearch(
LeafReaderContext leafReaderContext,
ExactSearcherContext exactSearcherContext,
KNNIterator iterator
) throws IOException {
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
final KNNQuery knnQuery = exactSearcherContext.getKnnQuery();
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
final KNNEngine engine = FieldInfoExtractor.extractKNNEngine(fieldInfo);
if (KNNEngine.FAISS != engine) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Engine [%s] does not support radial search", engine));
}
final SpaceType spaceType = FieldInfoExtractor.getSpaceType(modelDao, fieldInfo);
final float minScore = spaceType.scoreTranslation(knnQuery.getRadius());
return filterDocsByMinScore(exactSearcherContext, iterator, minScore);
}

private Map<Integer, Float> scoreAllDocs(KNNIterator iterator) throws IOException {
Expand All @@ -71,15 +106,17 @@ private Map<Integer, Float> scoreAllDocs(KNNIterator iterator) throws IOExceptio
return docToScore;
}

private Map<Integer, Float> searchTopK(KNNIterator iterator, int k) throws IOException {
private Map<Integer, Float> searchTopCandidates(KNNIterator iterator, int limit, @NonNull Predicate<Float> filterScore)
throws IOException {
// Creating min heap and init with MAX DocID and Score as -INF.
final HitQueue queue = new HitQueue(k, true);
final HitQueue queue = new HitQueue(limit, true);
ScoreDoc topDoc = queue.top();
final Map<Integer, Float> docToScore = new HashMap<>();
int docId;
while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
if (iterator.score() > topDoc.score) {
topDoc.score = iterator.score();
final float currentScore = iterator.score();
if (filterScore.test(currentScore) && currentScore > topDoc.score) {
topDoc.score = currentScore;
topDoc.doc = docId;
// As the HitQueue is min heap, updating top will bring the doc with -INF score or worst score we
// have seen till now on top.
Expand All @@ -98,10 +135,16 @@ private Map<Integer, Float> searchTopK(KNNIterator iterator, int k) throws IOExc
final ScoreDoc doc = queue.pop();
docToScore.put(doc.doc, doc.score);
}

return docToScore;
}

private Map<Integer, Float> filterDocsByMinScore(ExactSearcherContext context, KNNIterator iterator, float minScore)
throws IOException {
int maxResultWindow = context.getKnnQuery().getContext().getMaxResultWindow();
Predicate<Float> scoreGreaterThanOrEqualToMinScore = score -> score >= minScore;
return searchTopCandidates(iterator, maxResultWindow, scoreGreaterThanOrEqualToMinScore);
}

private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) throws IOException {
final KNNQuery knnQuery = exactSearcherContext.getKnnQuery();
final BitSet matchedDocs = exactSearcherContext.getMatchedDocs();
Expand Down
15 changes: 9 additions & 6 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.index.query;

import com.google.common.annotations.VisibleForTesting;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
Expand Down Expand Up @@ -95,8 +96,13 @@ public KNNWeight(KNNQuery query, float boost, Weight filterWeight) {
}

public static void initialize(ModelDao modelDao) {
initialize(modelDao, new ExactSearcher(modelDao));
}

@VisibleForTesting
static void initialize(ModelDao modelDao, ExactSearcher exactSearcher) {
KNNWeight.modelDao = modelDao;
KNNWeight.DEFAULT_EXACT_SEARCHER = new ExactSearcher(modelDao);
KNNWeight.DEFAULT_EXACT_SEARCHER = exactSearcher;
}

@Override
Expand Down Expand Up @@ -204,8 +210,8 @@ private int[] bitSetToIntArray(final BitSet bitSet) {

private Map<Integer, Float> doExactSearch(final LeafReaderContext context, final BitSet acceptedDocs, int k) throws IOException {
final ExactSearcherContextBuilder exactSearcherContextBuilder = ExactSearcher.ExactSearcherContext.builder()
.k(k)
.isParentHits(true)
.k(k)
// setting to true, so that if quantization details are present we want to do search on the quantized
// vectors as this flow is used in first pass of search.
.useQuantizedVectorsForSearch(true)
Expand Down Expand Up @@ -403,12 +409,9 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) {
filterIdsCount,
KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName())
);
if (knnQuery.getRadius() != null) {
return false;
}
int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName());
// Refer this GitHub around more details https://github.com/opensearch-project/k-NN/issues/1049 on the logic
if (filterIdsCount <= knnQuery.getK()) {
if (knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK()) {
return true;
}
// See user has defined Exact Search filtered threshold. if yes, then use that setting.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ public static Query create(RNNQueryFactory.CreateQueryRequest createQueryRequest
.indexName(indexName)
.parentsFilter(parentFilter)
.radius(radius)
.vectorDataType(vectorDataType)
.methodParameters(methodParameters)
.context(knnQueryContext)
.filterQuery(filterQuery)
Expand Down
119 changes: 115 additions & 4 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ public class FaissIT extends KNNRestTestCase {
private static final String INTEGER_FIELD_NAME = "int_field";
private static final String FILED_TYPE_INTEGER = "integer";
private static final String NON_EXISTENT_INTEGER_FIELD_NAME = "nonexistent_int_field";
public static final int NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD = -1;

static TestUtils.TestData testData;

Expand Down Expand Up @@ -622,10 +623,11 @@ public void testHNSWSQFP16_whenGraphThresholdIsNegative_whenIndexed_thenSkipCrea

// Assert we have the right number of documents in the index
assertEquals(numDocs, getDocCount(indexName));
// KNN Query should return empty result

final Response searchResponse = searchKNNIndex(indexName, buildSearchQuery(fieldName, 1, queryVector, null), 1);
final List<KNNResult> results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName);
assertEquals(0, results.size());
// expect result due to exact search
assertEquals(1, results.size());

deleteKNNIndex(indexName);
validateGraphEviction();
Expand Down Expand Up @@ -681,7 +683,7 @@ public void testHNSWSQFP16_whenGraphThresholdIsMetDuringMerge_thenCreateGraph()
// KNN Query should return empty result
final Response searchResponse = searchKNNIndex(indexName, buildSearchQuery(fieldName, 1, queryVector, null), 1);
final List<KNNResult> results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName);
assertEquals(0, results.size());
assertEquals(1, results.size());

// update index setting to build graph and do force merge
// update build vector data structure setting
Expand Down Expand Up @@ -1826,6 +1828,111 @@ public void testIVF_whenBinaryFormat_whenIVF_thenSuccess() {
validateGraphEviction();
}

@SneakyThrows
public void testEndToEnd_whenDoRadiusSearch_whenNoGraphFileIsCreated_whenDistanceThreshold_thenSucceed() {
final SpaceType spaceType = SpaceType.L2;

final List<Integer> mValues = ImmutableList.of(16, 32, 64, 128);
final List<Integer> efConstructionValues = ImmutableList.of(16, 32, 64, 128);
final List<Integer> efSearchValues = ImmutableList.of(16, 32, 64, 128);

final Integer dimension = testData.indexData.vectors[0].length;
final Settings knnIndexSettings = buildKNNIndexSettings(NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD);

// Create an index
final XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(FIELD_NAME)
.field("type", "knn_vector")
.field("dimension", dimension)
.startObject(KNN_METHOD)
.field(NAME, METHOD_HNSW)
.field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue())
.field(KNN_ENGINE, KNNEngine.FAISS.getName())
.startObject(PARAMETERS)
.field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size())))
.field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size())))
.field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size())))
.endObject()
.endObject()
.endObject()
.endObject()
.endObject();
createKnnIndex(INDEX_NAME, knnIndexSettings, builder.toString());

// Index the test data
for (int i = 0; i < testData.indexData.docs.length; i++) {
addKnnDoc(
INDEX_NAME,
Integer.toString(testData.indexData.docs[i]),
FIELD_NAME,
Floats.asList(testData.indexData.vectors[i]).toArray()
);
}

// Assert we have the right number of documents
refreshAllNonSystemIndices();
assertEquals(testData.indexData.docs.length, getDocCount(INDEX_NAME));

final float distance = 300000000000f;
final List<List<KNNResult>> resultsFromDistance = validateRadiusSearchResults(
INDEX_NAME,
FIELD_NAME,
testData.queries,
distance,
null,
spaceType,
null,
null
);
assertFalse(resultsFromDistance.isEmpty());
resultsFromDistance.forEach(result -> { assertFalse(result.isEmpty()); });
final float score = spaceType.scoreTranslation(distance);
final List<List<KNNResult>> resultsFromScore = validateRadiusSearchResults(
INDEX_NAME,
FIELD_NAME,
testData.queries,
null,
score,
spaceType,
null,
null
);
assertFalse(resultsFromScore.isEmpty());
resultsFromScore.forEach(result -> { assertFalse(result.isEmpty()); });

// Delete index
deleteKNNIndex(INDEX_NAME);
}

@SneakyThrows
public void testRadialQueryWithFilter_whenNoGraphIsCreated_thenSuccess() {
setupKNNIndexForFilterQuery(buildKNNIndexSettings(NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD));

final float[][] searchVector = new float[][] { { 3.3f, 3.0f, 5.0f } };
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery("color", "red");
List<String> expectedDocIds = Arrays.asList(DOC_ID_3);

float distance = 15f;
List<List<KNNResult>> queryResult = validateRadiusSearchResults(
INDEX_NAME,
FIELD_NAME,
searchVector,
distance,
null,
SpaceType.L2,
termQueryBuilder,
null
);

assertEquals(1, queryResult.get(0).size());
assertEquals(expectedDocIds.get(0), queryResult.get(0).get(0).getDocId());

// Delete index
deleteKNNIndex(INDEX_NAME);
}

@SneakyThrows
public void testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful() {
XContentBuilder builder = XContentFactory.jsonBuilder()
Expand Down Expand Up @@ -1898,6 +2005,10 @@ public void testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful(
}

protected void setupKNNIndexForFilterQuery() throws Exception {
setupKNNIndexForFilterQuery(getKNNDefaultIndexSettings());
}

protected void setupKNNIndexForFilterQuery(Settings settings) throws Exception {
// Create Mappings
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
Expand All @@ -1915,7 +2026,7 @@ protected void setupKNNIndexForFilterQuery() throws Exception {
.endObject();
final String mapping = builder.toString();

createKnnIndex(INDEX_NAME, mapping);
createKnnIndex(INDEX_NAME, settings, mapping);

addKnnDocWithAttributes(
DOC_ID_1,
Expand Down
Loading

0 comments on commit 87b1d0b

Please sign in to comment.