Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add timeout support to AbstractKnnVectorQuery #13202

Merged
merged 6 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,9 @@ Improvements
implementation is the ConcurrentMergeScheduler and the Lucene99HnswVectorsFormat will use it if no other
executor is provided. (Ben Trent)

* GITHUB#13202: Terminate graph and exact searches of AbstractKnnVectorQuery early to honor the timeout set via
IndexSearcher#setTimeout(QueryTimeout). (Kaival Parikh)

Optimizations
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,9 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
// and collect them
for (int i = 0; i < scorer.maxOrd(); i++) {
if (acceptedOrds == null || acceptedOrds.get(i)) {
if (knnCollector.earlyTerminated()) {
break;
}
knnCollector.incVisitedCount(1);
knnCollector.collect(scorer.ordToDoc(i), scorer.score(i));
}
Expand Down Expand Up @@ -279,6 +282,9 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits
// and collect them
for (int i = 0; i < scorer.maxOrd(); i++) {
if (acceptedOrds == null || acceptedOrds.get(i)) {
if (knnCollector.earlyTerminated()) {
break;
}
knnCollector.incVisitedCount(1);
knnCollector.collect(scorer.ordToDoc(i), scorer.score(i));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.TopKnnCollectorManager;
import org.apache.lucene.util.BitSet;
Expand Down Expand Up @@ -81,7 +82,9 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
filterWeight = null;
}

KnnCollectorManager knnCollectorManager = getKnnCollectorManager(k, indexSearcher);
TimeLimitingKnnCollectorManager knnCollectorManager =
new TimeLimitingKnnCollectorManager(
getKnnCollectorManager(k, indexSearcher), indexSearcher.getTimeout());
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
List<LeafReaderContext> leafReaderContexts = reader.leaves();
List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
Expand All @@ -99,9 +102,11 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
}

private TopDocs searchLeaf(
LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager)
LeafReaderContext ctx,
Weight filterWeight,
TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager)
throws IOException {
TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager);
TopDocs results = getLeafResults(ctx, filterWeight, timeLimitingKnnCollectorManager);
if (ctx.docBase > 0) {
for (ScoreDoc scoreDoc : results.scoreDocs) {
scoreDoc.doc += ctx.docBase;
Expand All @@ -111,13 +116,15 @@ private TopDocs searchLeaf(
}

private TopDocs getLeafResults(
LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager)
LeafReaderContext ctx,
Weight filterWeight,
TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager)
throws IOException {
Bits liveDocs = ctx.reader().getLiveDocs();
int maxDoc = ctx.reader().maxDoc();

if (filterWeight == null) {
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager);
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager);
}

Scorer scorer = filterWeight.scorer(ctx);
Expand All @@ -127,21 +134,24 @@ private TopDocs getLeafResults(

BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
final int cost = acceptDocs.cardinality();
QueryTimeout queryTimeout = timeLimitingKnnCollectorManager.getQueryTimeout();

if (cost <= k) {
// If there are <= k possible matches, short-circuit and perform exact search, since HNSW
// must always visit at least k documents
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost), queryTimeout);
}

// Perform the approximate kNN search
// We pass cost + 1 here to account for the edge case when we explore exactly cost vectors
TopDocs results = approximateSearch(ctx, acceptDocs, cost + 1, knnCollectorManager);
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
TopDocs results = approximateSearch(ctx, acceptDocs, cost + 1, timeLimitingKnnCollectorManager);
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO
// Return partial results only when timeout is met
|| (queryTimeout != null && queryTimeout.shouldExit())) {
return results;
} else {
// We stopped the kNN search because it visited too many nodes, so fall back to exact search
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost), queryTimeout);
}
}

Expand Down Expand Up @@ -178,7 +188,8 @@ abstract VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi
throws IOException;

// We allow this to be overridden so that tests can check what search strategy is used
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
protected TopDocs exactSearch(
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout)
throws IOException {
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
Expand All @@ -192,9 +203,16 @@ protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator accept
}
final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost()));
HitQueue queue = new HitQueue(queueSize, true);
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
ScoreDoc topDoc = queue.top();
int doc;
while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
// Mark results as partial if timeout is met
if (queryTimeout != null && queryTimeout.shouldExit()) {
relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
break;
}

boolean advanced = vectorScorer.advanceExact(doc);
assert advanced;

Expand All @@ -216,7 +234,7 @@ protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator accept
topScoreDocs[i] = queue.pop();
}

TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation);
return new TopDocs(totalHits, topScoreDocs);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,14 @@ public TopDocs searchAfter(ScoreDoc after, Query query, int numHits) throws IOEx
return search(query, manager);
}

/**
 * Returns the {@link QueryTimeout} that applies to all searches run through this {@link
 * IndexSearcher}.
 *
 * @return the configured timeout, or {@code null} if not set
 */
public QueryTimeout getTimeout() {
  return queryTimeout;
}

/** Set a {@link QueryTimeout} for all searches that run through this {@link IndexSearcher}. */
public void setTimeout(QueryTimeout queryTimeout) {
this.queryTimeout = queryTimeout;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;

import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.search.knn.KnnCollectorManager;

/** A {@link KnnCollectorManager} that collects results with a timeout. */
public class TimeLimitingKnnCollectorManager implements KnnCollectorManager {
private final KnnCollectorManager delegate;
private final QueryTimeout queryTimeout;

/**
 * Wraps {@code delegate} so that the collectors it creates also check {@code timeout}.
 *
 * @param delegate the manager that creates the underlying collectors
 * @param timeout the timeout to consult during collection; may be {@code null}, in which case
 *     {@link #newCollector} returns the delegate's collector unwrapped
 */
public TimeLimitingKnnCollectorManager(KnnCollectorManager delegate, QueryTimeout timeout) {
  queryTimeout = timeout;
  this.delegate = delegate;
}

/**
 * Returns the {@link QueryTimeout} used for terminating graph and exact searches.
 *
 * @return the timeout passed to the constructor; {@code null} when none was supplied
 */
public QueryTimeout getQueryTimeout() {
  return this.queryTimeout;
}

@Override
public KnnCollector newCollector(int visitedLimit, LeafReaderContext context) throws IOException {
KnnCollector collector = delegate.newCollector(visitedLimit, context);
if (queryTimeout == null) {
return collector;
}
return new KnnCollector() {
@Override
public boolean earlyTerminated() {
return queryTimeout.shouldExit() || collector.earlyTerminated();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, so both searchLevel() and findBestEntryPoint() in HnswGraphSearcher check earlyTerminated() on their collector to abort the search, so configuring the timeout check allows us to preempt the approximate search.

}

// The overrides below forward unchanged to the wrapped collector; none of them consults the
// query timeout.

@Override
public void incVisitedCount(int count) {
collector.incVisitedCount(count);
}

@Override
public long visitedCount() {
return collector.visitedCount();
}

@Override
public long visitLimit() {
return collector.visitLimit();
}

@Override
public int k() {
return collector.k();
}

@Override
public boolean collect(int docId, float similarity) {
return collector.collect(docId, similarity);
}

@Override
public float minCompetitiveSimilarity() {
return collector.minCompetitiveSimilarity();
}

@Override
public TopDocs topDocs() {
TopDocs docs = collector.topDocs();

// Mark results as partial if timeout is met
TotalHits.Relation relation =
queryTimeout.shouldExit()
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: docs.totalHits.relation;

return new TopDocs(new TotalHits(docs.totalHits.value, relation), docs.scoreDocs);
Comment on lines +85 to +91
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we simply return collector.topDocs(); and let collectors decide how to handle this? All implementations of AbstractKnnCollector already set relation to GTE based on earlyTerminated() check.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the collector.topDocs() will set the relation to GTE in case of timeouts (it will only set this when the visitLimit() is crossed, because it will use its own internal earlyTerminated() function that does not have information about the QueryTimeout)

}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorSimilarityFunction;
Expand Down Expand Up @@ -765,6 +766,34 @@ public void testBitSetQuery() throws IOException {
}
}

/**
 * Verifies that both the approximate and the exact kNN search paths honor the {@link
 * QueryTimeout} configured on the searcher.
 */
public void testTimeout() throws IOException {
  try (Directory dir =
          getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
      IndexReader indexReader = DirectoryReader.open(dir)) {
    IndexSearcher indexSearcher = newSearcher(indexReader);

    AbstractKnnVectorQuery approximateQuery =
        getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 2);
    // Filtered query asking for more hits than there are documents; used for the exact search
    AbstractKnnVectorQuery filteredQuery =
        getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10, new MatchAllDocsQuery());

    // Baseline without any timeout: both queries return results
    assertEquals(2, indexSearcher.count(approximateQuery));
    assertEquals(3, indexSearcher.count(filteredQuery));

    // A timeout that has already expired: neither query may return anything
    QueryTimeout alreadyExpired = () -> true;
    indexSearcher.setTimeout(alreadyExpired);
    assertEquals(0, indexSearcher.count(approximateQuery));
    assertEquals(0, indexSearcher.count(filteredQuery));

    // A budget of one check lets exactly one document be scored.
    // Note: this relies on the HNSW graph having a single layer; with multiple layers the
    // expected count would be 0.
    indexSearcher.setTimeout(new CountingQueryTimeout(1));
    assertEquals(1, indexSearcher.count(approximateQuery));

    // Fresh budget of one check for the exact search: again exactly one result
    indexSearcher.setTimeout(new CountingQueryTimeout(1));
    assertEquals(1, indexSearcher.count(filteredQuery));
  }
}

/** Creates a new directory and adds documents with the given vectors as kNN vector fields */
Directory getIndexStore(String field, float[]... contents) throws IOException {
return getIndexStore(field, VectorSimilarityFunction.EUCLIDEAN, contents);
Expand Down Expand Up @@ -1006,4 +1035,21 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
}
}
}

/**
 * A {@link QueryTimeout} that answers {@code false} for a fixed number of {@link #shouldExit()}
 * calls and {@code true} for every call after that.
 */
private static class CountingQueryTimeout implements QueryTimeout {
  // Number of shouldExit() calls that may still return false
  private int checksLeft;

  public CountingQueryTimeout(int count) {
    this.checksLeft = count;
  }

  @Override
  public boolean shouldExit() {
    if (checksLeft <= 0) {
      // Budget exhausted: signal that the search should stop
      return true;
    }
    checksLeft--;
    return false;
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.TestVectorUtil;
Expand Down Expand Up @@ -109,7 +110,8 @@ public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter)
}

@Override
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) {
protected TopDocs exactSearch(
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) {
throw new UnsupportedOperationException("exact search is not supported");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
Expand Down Expand Up @@ -258,7 +259,8 @@ public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter)
}

@Override
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) {
protected TopDocs exactSearch(
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) {
throw new UnsupportedOperationException("exact search is not supported");
}

Expand Down
Loading