Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add timeout support to AbstractKnnVectorQuery #13202

Merged
merged 6 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.TopKnnCollectorManager;
import org.apache.lucene.util.BitSet;
Expand Down Expand Up @@ -81,7 +82,9 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
filterWeight = null;
}

KnnCollectorManager knnCollectorManager = getKnnCollectorManager(k, indexSearcher);
TimeLimitingKnnCollectorManager knnCollectorManager =
new TimeLimitingKnnCollectorManager(
getKnnCollectorManager(k, indexSearcher), indexSearcher.getTimeout());
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
List<LeafReaderContext> leafReaderContexts = reader.leaves();
List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
Expand All @@ -99,9 +102,11 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
}

private TopDocs searchLeaf(
LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager)
LeafReaderContext ctx,
Weight filterWeight,
TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager)
throws IOException {
TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager);
TopDocs results = getLeafResults(ctx, filterWeight, timeLimitingKnnCollectorManager);
if (ctx.docBase > 0) {
for (ScoreDoc scoreDoc : results.scoreDocs) {
scoreDoc.doc += ctx.docBase;
Expand All @@ -111,13 +116,15 @@ private TopDocs searchLeaf(
}

private TopDocs getLeafResults(
LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager)
LeafReaderContext ctx,
Weight filterWeight,
TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager)
throws IOException {
Bits liveDocs = ctx.reader().getLiveDocs();
int maxDoc = ctx.reader().maxDoc();

if (filterWeight == null) {
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager);
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager);
}

Scorer scorer = filterWeight.scorer(ctx);
Expand All @@ -127,21 +134,24 @@ private TopDocs getLeafResults(

BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
final int cost = acceptDocs.cardinality();
QueryTimeout queryTimeout = timeLimitingKnnCollectorManager.getQueryTimeout();

if (cost <= k) {
// If there are <= k possible matches, short-circuit and perform exact search, since HNSW
// must always visit at least k documents
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost), queryTimeout);
}

// Perform the approximate kNN search
// We pass cost + 1 here to account for the edge case when we explore exactly cost vectors
TopDocs results = approximateSearch(ctx, acceptDocs, cost + 1, knnCollectorManager);
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
TopDocs results = approximateSearch(ctx, acceptDocs, cost + 1, timeLimitingKnnCollectorManager);
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO
// Return partial results only when timeout is met
|| (queryTimeout != null && queryTimeout.shouldExit())) {
return results;
} else {
// We stopped the kNN search because it visited too many nodes, so fall back to exact search
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
return exactSearch(ctx, new BitSetIterator(acceptDocs, cost), queryTimeout);
}
}

Expand Down Expand Up @@ -178,7 +188,8 @@ abstract VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi
throws IOException;

// We allow this to be overridden so that tests can check what search strategy is used
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
protected TopDocs exactSearch(
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout)
throws IOException {
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
Expand All @@ -192,9 +203,16 @@ protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator accept
}
final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost()));
HitQueue queue = new HitQueue(queueSize, true);
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
ScoreDoc topDoc = queue.top();
int doc;
while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
// Mark results as partial if timeout is met
if (queryTimeout != null && queryTimeout.shouldExit()) {
relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
break;
}

boolean advanced = vectorScorer.advanceExact(doc);
assert advanced;

Expand All @@ -216,7 +234,7 @@ protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator accept
topScoreDocs[i] = queue.pop();
}

TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation);
return new TopDocs(totalHits, topScoreDocs);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,11 @@ public TopDocs searchAfter(ScoreDoc after, Query query, int numHits) throws IOEx
return search(query, manager);
}

/**
 * Get a {@link QueryTimeout} for all searches that run through this {@link IndexSearcher}.
 *
 * @return the timeout previously installed via {@link #setTimeout}; may be {@code null} when no
 *     timeout has been configured
 */
public QueryTimeout getTimeout() {
  return this.queryTimeout;
}

/** Set a {@link QueryTimeout} for all searches that run through this {@link IndexSearcher}. */
public void setTimeout(QueryTimeout queryTimeout) {
this.queryTimeout = queryTimeout;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;

import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.search.knn.KnnCollectorManager;

/** A {@link KnnCollectorManager} that collects results with a timeout. */
public class TimeLimitingKnnCollectorManager implements KnnCollectorManager {
private final KnnCollectorManager delegate;
private final QueryTimeout queryTimeout;

public TimeLimitingKnnCollectorManager(KnnCollectorManager delegate, QueryTimeout timeout) {
this.delegate = delegate;
this.queryTimeout = timeout;
}

/** Get the {@link QueryTimeout} for terminating graph searches. */
public QueryTimeout getQueryTimeout() {
return queryTimeout;
}

@Override
public KnnCollector newCollector(int visitedLimit, LeafReaderContext context) throws IOException {
KnnCollector collector = delegate.newCollector(visitedLimit, context);
if (queryTimeout == null) {
return collector;
}
return new KnnCollector() {
@Override
public boolean earlyTerminated() {
return queryTimeout.shouldExit() || collector.earlyTerminated();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, so both searchLevel() and findBestEntryPoint() in HnswGraphSearcher check earlyTerminated() on their collector to abort the search, so configuring the timeout check allows us to preempt the approximate search.

}

@Override
public void incVisitedCount(int count) {
collector.incVisitedCount(count);
}

@Override
public long visitedCount() {
return collector.visitedCount();
}

@Override
public long visitLimit() {
return collector.visitLimit();
}

@Override
public int k() {
return collector.k();
}

@Override
public boolean collect(int docId, float similarity) {
return collector.collect(docId, similarity);
}

@Override
public float minCompetitiveSimilarity() {
return collector.minCompetitiveSimilarity();
}

@Override
public TopDocs topDocs() {
TopDocs docs = collector.topDocs();

// Mark results as partial if timeout is met
TotalHits.Relation relation =
queryTimeout.shouldExit()
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: docs.totalHits.relation;

return new TopDocs(new TotalHits(docs.totalHits.value, relation), docs.scoreDocs);
Comment on lines +85 to +91
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we simply return collector.topDocs(); and let collectors decide how to handle this? All implementations of AbstractKnnCollector already set relation to GTE based on earlyTerminated() check.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the collector.topDocs() will set the relation to GTE in case of timeouts (it will only set this when the visitLimit() is crossed, because it will use its own internal earlyTerminated() function that does not have information about the QueryTimeout)

}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.TestVectorUtil;
Expand Down Expand Up @@ -102,14 +103,34 @@ public void testVectorEncodingMismatch() throws IOException {
}
}

public void testTimeout() throws IOException {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also add a test for the partial result case. You could create a mock query timeout that returns false only for the first time it's called, and then flips to returning true.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense!

There was a consideration here that the number of levels in the HNSW graph should be == 1, because if the timeout is hit while finding the best entry point for the last level, we haven't collected any results yet. I think this should be fine as we're only indexing 3 vectors, and running these tests for a few thousand times did not give an error. Added a note about this as well

Also fixed another place where the timeout needs to be checked

try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);

AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 2);
AbstractKnnVectorQuery exactQuery =
getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10, new MatchAllDocsQuery());

assertEquals(2, searcher.count(query)); // Expect some results without timeout
assertEquals(3, searcher.count(exactQuery)); // Same for exact search

searcher.setTimeout(() -> true); // Immediately timeout
assertEquals(0, searcher.count(query)); // Expect no results with the timeout
assertEquals(0, searcher.count(exactQuery)); // Same for exact search
}
}

private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery {

public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) {
super(field, target, k, filter);
}

@Override
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) {
protected TopDocs exactSearch(
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) {
throw new UnsupportedOperationException("exact search is not supported");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
Expand Down Expand Up @@ -251,14 +252,34 @@ public void testDocAndScoreQueryBasics() throws IOException {
}
}

/**
 * Verifies timeout handling: without a timeout both approximate and exact search return hits,
 * while an immediately-expiring {@link QueryTimeout} makes both return zero hits.
 */
public void testTimeout() throws IOException {
  try (Directory indexStore =
          getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
      IndexReader reader = DirectoryReader.open(indexStore)) {
    IndexSearcher searcher = newSearcher(reader);

    // The filtered k=10 query matches all 3 docs with its filter, so it exercises the
    // exact-search path; the unfiltered k=2 query exercises the approximate path.
    AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 2);
    AbstractKnnVectorQuery exactQuery =
        getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10, new MatchAllDocsQuery());

    assertEquals(2, searcher.count(query)); // Expect some results without timeout
    assertEquals(3, searcher.count(exactQuery)); // Same for exact search

    searcher.setTimeout(() -> true); // Immediately timeout
    assertEquals(0, searcher.count(query)); // Expect no results with the timeout
    assertEquals(0, searcher.count(exactQuery)); // Same for exact search
  }
}

private static class ThrowingKnnVectorQuery extends KnnFloatVectorQuery {

public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter) {
super(field, target, k, filter);
}

@Override
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) {
protected TopDocs exactSearch(
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) {
throw new UnsupportedOperationException("exact search is not supported");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
Expand Down Expand Up @@ -77,7 +78,8 @@ public DiversifyingChildrenByteKnnVectorQuery(
}

@Override
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
protected TopDocs exactSearch(
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout)
throws IOException {
ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(field);
if (byteVectorValues == null) {
Expand All @@ -100,8 +102,15 @@ protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator accept
fi.getVectorSimilarityFunction());
final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost()));
HitQueue queue = new HitQueue(queueSize, true);
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
ScoreDoc topDoc = queue.top();
while (vectorScorer.nextParent() != DocIdSetIterator.NO_MORE_DOCS) {
// Mark results as partial if timeout is met
if (queryTimeout != null && queryTimeout.shouldExit()) {
relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
break;
}

float score = vectorScorer.score();
if (score > topDoc.score) {
topDoc.score = score;
Expand All @@ -120,7 +129,7 @@ protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator accept
topScoreDocs[i] = queue.pop();
}

TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation);
return new TopDocs(totalHits, topScoreDocs);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
Expand Down Expand Up @@ -77,7 +78,8 @@ public DiversifyingChildrenFloatKnnVectorQuery(
}

@Override
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
protected TopDocs exactSearch(
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout)
throws IOException {
FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(field);
if (floatVectorValues == null) {
Expand All @@ -100,8 +102,15 @@ protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator accept
fi.getVectorSimilarityFunction());
final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost()));
HitQueue queue = new HitQueue(queueSize, true);
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
ScoreDoc topDoc = queue.top();
while (vectorScorer.nextParent() != DocIdSetIterator.NO_MORE_DOCS) {
// Mark results as partial if timeout is met
if (queryTimeout != null && queryTimeout.shouldExit()) {
relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
break;
}
Comment on lines 107 to +112
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also wanted some opinions here: we're checking the timeout in exact search of DiversifyingChildren[Byte|Float]KnnVectorQuery once per-parent (as opposed to once per document elsewhere)

Should we update this to once per-child as well?

Copy link
Contributor

@vigyasharma vigyasharma Apr 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Timeouts are inevitably approximate, so I guess it should be fine. Also, this seems like something that we can easily change in a follow-up PR if needed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, thanks!


float score = vectorScorer.score();
if (score > topDoc.score) {
topDoc.score = score;
Expand All @@ -120,7 +129,7 @@ protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator accept
topScoreDocs[i] = queue.pop();
}

TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation);
return new TopDocs(totalHits, topScoreDocs);
}

Expand Down
Loading