From d700b91fed00bb0c551c995ed99e84d373b82a36 Mon Sep 17 00:00:00 2001 From: Kaival Parikh Date: Fri, 22 Mar 2024 19:48:30 +0000 Subject: [PATCH 1/5] Add timeout support to graph searches in AbstractKnnVectorQuery --- .../lucene/search/AbstractKnnVectorQuery.java | 4 +- .../apache/lucene/search/IndexSearcher.java | 5 + .../TimeLimitingKnnCollectorManager.java | 91 +++++++++++++++++++ .../lucene/search/TestKnnByteVectorQuery.java | 14 +++ .../search/TestKnnFloatVectorQuery.java | 14 +++ ...TestParentBlockJoinByteKnnVectorQuery.java | 17 ++++ ...estParentBlockJoinFloatKnnVectorQuery.java | 17 ++++ 7 files changed, 161 insertions(+), 1 deletion(-) create mode 100644 lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index 8e4cf3c3d333..c294ef233b80 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -81,7 +81,9 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { filterWeight = null; } - KnnCollectorManager knnCollectorManager = getKnnCollectorManager(k, indexSearcher); + KnnCollectorManager knnCollectorManager = + new TimeLimitingKnnCollectorManager( + getKnnCollectorManager(k, indexSearcher), indexSearcher.getTimeout()); TaskExecutor taskExecutor = indexSearcher.getTaskExecutor(); List leafReaderContexts = reader.leaves(); List> tasks = new ArrayList<>(leafReaderContexts.size()); diff --git a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java index 5157b4da054f..bb937355b4a7 100644 --- a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java @@ -483,6 +483,11 @@ public TopDocs searchAfter(ScoreDoc after, Query query, int numHits) throws IOEx return search(query, manager); } + /** Get a {@link QueryTimeout} for all searches that run through this {@link IndexSearcher}. */ + public QueryTimeout getTimeout() { + return this.queryTimeout; + } + /** Set a {@link QueryTimeout} for all searches that run through this {@link IndexSearcher}. */ public void setTimeout(QueryTimeout queryTimeout) { this.queryTimeout = queryTimeout; diff --git a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java new file mode 100644 index 000000000000..20c0d6bb4dfe --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.search.knn.KnnCollectorManager; + +/** A {@link KnnCollectorManager} that collects results with a timeout. */ +public class TimeLimitingKnnCollectorManager implements KnnCollectorManager { + private final KnnCollectorManager delegate; + private final QueryTimeout queryTimeout; + + public TimeLimitingKnnCollectorManager(KnnCollectorManager delegate, QueryTimeout timeout) { + this.delegate = delegate; + this.queryTimeout = timeout; + } + + /** Get the {@link QueryTimeout} for terminating graph searches. */ + public QueryTimeout getQueryTimeout() { + return queryTimeout; + } + + @Override + public KnnCollector newCollector(int visitedLimit, LeafReaderContext context) throws IOException { + KnnCollector collector = delegate.newCollector(visitedLimit, context); + if (queryTimeout == null) { + return collector; + } + return new KnnCollector() { + @Override + public boolean earlyTerminated() { + return queryTimeout.shouldExit() || collector.earlyTerminated(); + } + + @Override + public void incVisitedCount(int count) { + collector.incVisitedCount(count); + } + + @Override + public long visitedCount() { + return collector.visitedCount(); + } + + @Override + public long visitLimit() { + return collector.visitLimit(); + } + + @Override + public int k() { + return collector.k(); + } + + @Override + public boolean collect(int docId, float similarity) { + return collector.collect(docId, similarity); + } + + @Override + public float minCompetitiveSimilarity() { + return collector.minCompetitiveSimilarity(); + } + + @Override + public TopDocs topDocs() { + if (queryTimeout.shouldExit()) { + // Quickly exit + return TopDocsCollector.EMPTY_TOPDOCS; + } + return collector.topDocs(); + } + }; + } +} diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java index 1b912ae7aad4..f6c51d39a741 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java @@ -102,6 +102,20 @@ public void testVectorEncodingMismatch() throws IOException { } } + public void testTimeout() throws IOException { + try (Directory indexStore = + getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); + IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + + AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 2); + assertEquals(2, searcher.count(query)); // Expect some results without timeout + + searcher.setTimeout(() -> true); // Immediately timeout + assertEquals(0, searcher.count(query)); // Expect no results with the timeout + } + } + private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery { public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) { diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java index 60969663361d..c90656dd437f 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java @@ -251,6 +251,20 @@ public void testDocAndScoreQueryBasics() throws IOException { } } + public void testTimeout() throws IOException { + try (Directory indexStore = + getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); + IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + + AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 2); + assertEquals(2, searcher.count(query)); // Expect some results without timeout + + searcher.setTimeout(() -> true); // Immediately timeout + assertEquals(0, searcher.count(query)); // Expect no results with the timeout + } + } + private static class ThrowingKnnVectorQuery extends KnnFloatVectorQuery { public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter) { diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java index 5b70763d044b..7f34f8153e9e 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java @@ -67,6 +67,23 @@ public void testVectorEncodingMismatch() throws IOException { } } + public void testTimeout() throws IOException { + try (Directory indexStore = + getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); + IndexReader reader = DirectoryReader.open(indexStore)) { + BitSetProducer parentFilter = parentFilter(reader); + IndexSearcher searcher = newSearcher(reader); + + Query query = + new DiversifyingChildrenByteKnnVectorQuery( + "field", new byte[] {1, 2}, null, 2, parentFilter); + assertEquals(2, searcher.count(query)); // Expect some results without timeout + + searcher.setTimeout(() -> true); // Immediately timeout + assertEquals(0, searcher.count(query)); // Expect no results with the timeout + } + } + private static byte[] fromFloat(float[] queryVector) { byte[] query = new byte[queryVector.length]; for (int i = 0; i < queryVector.length; i++) { diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java index fee519a1e1d6..7be41a73c267 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java @@ -103,6 +103,23 @@ d, new IndexWriterConfig().setMergePolicy(newMergePolicy(random(), false)))) { } } + public void testTimeout() throws IOException { + try (Directory indexStore = + getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); + IndexReader reader = DirectoryReader.open(indexStore)) { + BitSetProducer parentFilter = parentFilter(reader); + IndexSearcher searcher = newSearcher(reader); + + Query query = + new DiversifyingChildrenFloatKnnVectorQuery( + "field", new float[] {1, 2}, null, 2, parentFilter); + assertEquals(2, searcher.count(query)); // Expect some results without timeout + + searcher.setTimeout(() -> true); // Immediately timeout + assertEquals(0, searcher.count(query)); // Expect no results with the timeout + } + } + @Override Field getKnnVectorField(String name, float[] vector) { return new KnnFloatVectorField(name, vector); From b7f0ed23ecf94d0d53263dc7ef1e60e284b2b298 Mon Sep 17 00:00:00 2001 From: Kaival Parikh Date: Fri, 22 Mar 2024 21:26:27 +0000 Subject: [PATCH 2/5] Also timeout exact searches --- .../lucene/search/AbstractKnnVectorQuery.java | 34 ++++++++++++++----- .../lucene/search/TestKnnByteVectorQuery.java | 9 ++++- .../search/TestKnnFloatVectorQuery.java | 9 ++++- ...iversifyingChildrenByteKnnVectorQuery.java | 8 ++++- ...versifyingChildrenFloatKnnVectorQuery.java | 8 ++++- ...TestParentBlockJoinByteKnnVectorQuery.java | 7 ++++ ...estParentBlockJoinFloatKnnVectorQuery.java | 7 ++++ 7 files changed, 69 insertions(+), 13 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index c294ef233b80..16e88d858550 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -29,6 +29,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.search.knn.TopKnnCollectorManager; import org.apache.lucene.util.BitSet; @@ -81,7 +82,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { filterWeight = null; } - KnnCollectorManager knnCollectorManager = + TimeLimitingKnnCollectorManager knnCollectorManager = new TimeLimitingKnnCollectorManager( getKnnCollectorManager(k, indexSearcher), indexSearcher.getTimeout()); TaskExecutor taskExecutor = indexSearcher.getTaskExecutor(); @@ -101,9 +102,11 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { } private TopDocs searchLeaf( - LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) + LeafReaderContext ctx, + Weight filterWeight, + TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager) throws IOException { - TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager); + TopDocs results = getLeafResults(ctx, filterWeight, timeLimitingKnnCollectorManager); if (ctx.docBase > 0) { for (ScoreDoc scoreDoc : results.scoreDocs) { scoreDoc.doc += ctx.docBase; @@ -113,13 +116,15 @@ private TopDocs searchLeaf( } private TopDocs getLeafResults( - LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) + LeafReaderContext ctx, + Weight filterWeight, + TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager) throws IOException { Bits liveDocs = ctx.reader().getLiveDocs(); int maxDoc = ctx.reader().maxDoc(); if (filterWeight == null) { - return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager); + return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager); } Scorer scorer = filterWeight.scorer(ctx); @@ -133,17 +138,23 @@ private TopDocs getLeafResults( if (cost <= k) { // If there are <= k possible matches, short-circuit and perform exact search, since HNSW // must always visit at least k documents - return exactSearch(ctx, new BitSetIterator(acceptDocs, cost)); + return exactSearch( + ctx, + new BitSetIterator(acceptDocs, cost), + timeLimitingKnnCollectorManager.getQueryTimeout()); } // Perform the approximate kNN search // We pass cost + 1 here to account for the edge case when we explore exactly cost vectors - TopDocs results = approximateSearch(ctx, acceptDocs, cost + 1, knnCollectorManager); + TopDocs results = approximateSearch(ctx, acceptDocs, cost + 1, timeLimitingKnnCollectorManager); if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) { return results; } else { // We stopped the kNN search because it visited too many nodes, so fall back to exact search - return exactSearch(ctx, new BitSetIterator(acceptDocs, cost)); + return exactSearch( + ctx, + new BitSetIterator(acceptDocs, cost), + timeLimitingKnnCollectorManager.getQueryTimeout()); } } @@ -180,7 +191,8 @@ abstract VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi throws IOException; // We allow this to be overridden so that tests can check what search strategy is used - protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) + protected TopDocs exactSearch( + LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) throws IOException { FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); if (fi == null || fi.getVectorDimension() == 0) { @@ -197,6 +209,10 @@ protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator accept ScoreDoc topDoc = queue.top(); int doc; while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + if (queryTimeout != null && queryTimeout.shouldExit()) { + return NO_RESULTS; + } + boolean advanced = vectorScorer.advanceExact(doc); assert advanced; diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java index f6c51d39a741..676f96857e3a 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java @@ -22,6 +22,7 @@ import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.util.TestVectorUtil; @@ -109,10 +110,15 @@ public void testTimeout() throws IOException { IndexSearcher searcher = newSearcher(reader); AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 2); + AbstractKnnVectorQuery exactQuery = + getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10, new MatchAllDocsQuery()); + assertEquals(2, searcher.count(query)); // Expect some results without timeout + assertEquals(3, searcher.count(exactQuery)); // Same for exact search searcher.setTimeout(() -> true); // Immediately timeout assertEquals(0, searcher.count(query)); // Expect no results with the timeout + assertEquals(0, searcher.count(exactQuery)); // Same for exact search } } @@ -123,7 +129,8 @@ public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) } @Override - protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) { + protected TopDocs exactSearch( + LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) { throw new UnsupportedOperationException("exact search is not supported"); } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java index c90656dd437f..d1ca3f7d7410 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java @@ -34,6 +34,7 @@ import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; @@ -258,10 +259,15 @@ public void testTimeout() throws IOException { IndexSearcher searcher = newSearcher(reader); AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 2); + AbstractKnnVectorQuery exactQuery = + getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10, new MatchAllDocsQuery()); + assertEquals(2, searcher.count(query)); // Expect some results without timeout + assertEquals(3, searcher.count(exactQuery)); // Same for exact search searcher.setTimeout(() -> true); // Immediately timeout assertEquals(0, searcher.count(query)); // Expect no results with the timeout + assertEquals(0, searcher.count(exactQuery)); // Same for exact search } } @@ -272,7 +278,8 @@ public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter) } @Override - protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) { + protected TopDocs exactSearch( + LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) { throw new UnsupportedOperationException("exact search is not supported"); } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java index ef0a30282098..9d628d8d44f9 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java @@ -22,6 +22,7 @@ import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.HitQueue; @@ -77,7 +78,8 @@ public DiversifyingChildrenByteKnnVectorQuery( } @Override - protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) + protected TopDocs exactSearch( + LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) throws IOException { ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(field); if (byteVectorValues == null) { @@ -102,6 +104,10 @@ protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator accept HitQueue queue = new HitQueue(queueSize, true); ScoreDoc topDoc = queue.top(); while (vectorScorer.nextParent() != DocIdSetIterator.NO_MORE_DOCS) { + if (queryTimeout != null && queryTimeout.shouldExit()) { + return NO_RESULTS; + } + float score = vectorScorer.score(); if (score > topDoc.score) { topDoc.score = score; diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java index 75a4a1a62ae3..ac1def0823a1 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java @@ -22,6 +22,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.HitQueue; @@ -77,7 +78,8 @@ public DiversifyingChildrenFloatKnnVectorQuery( } @Override - protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) + protected TopDocs exactSearch( + LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) throws IOException { FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(field); if (floatVectorValues == null) { @@ -102,6 +104,10 @@ protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator accept HitQueue queue = new HitQueue(queueSize, true); ScoreDoc topDoc = queue.top(); while (vectorScorer.nextParent() != DocIdSetIterator.NO_MORE_DOCS) { + if (queryTimeout != null && queryTimeout.shouldExit()) { + return NO_RESULTS; + } + float score = vectorScorer.score(); if (score > topDoc.score) { topDoc.score = score; diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java index 7f34f8153e9e..91f807440aa3 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java @@ -25,6 +25,7 @@ import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TermQuery; import org.apache.lucene.store.Directory; @@ -77,10 +78,16 @@ public void testTimeout() throws IOException { Query query = new DiversifyingChildrenByteKnnVectorQuery( "field", new byte[] {1, 2}, null, 2, parentFilter); + Query exactQuery = + new DiversifyingChildrenByteKnnVectorQuery( + "field", new byte[] {1, 2}, new MatchAllDocsQuery(), 10, parentFilter); + assertEquals(2, searcher.count(query)); // Expect some results without timeout + assertEquals(3, searcher.count(exactQuery)); // Same for exact search searcher.setTimeout(() -> true); // Immediately timeout assertEquals(0, searcher.count(query)); // Expect no results with the timeout + assertEquals(0, searcher.count(exactQuery)); // Same for exact search } } diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java index 7be41a73c267..e8f246a53f98 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java @@ -32,6 +32,7 @@ import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TermQuery; import org.apache.lucene.store.Directory; @@ -113,10 +114,16 @@ public void testTimeout() throws IOException { Query query = new DiversifyingChildrenFloatKnnVectorQuery( "field", new float[] {1, 2}, null, 2, parentFilter); + Query exactQuery = + new DiversifyingChildrenFloatKnnVectorQuery( + "field", new float[] {1, 2}, new MatchAllDocsQuery(), 10, parentFilter); + assertEquals(2, searcher.count(query)); // Expect some results without timeout + assertEquals(3, searcher.count(exactQuery)); // Same for exact search searcher.setTimeout(() -> true); // Immediately timeout assertEquals(0, searcher.count(query)); // Expect no results with the timeout + assertEquals(0, searcher.count(exactQuery)); // Same for exact search } } From 31a2643242d83b0359e950904030a15a130913f6 Mon Sep 17 00:00:00 2001 From: Kaival Parikh Date: Mon, 25 Mar 2024 10:17:42 +0000 Subject: [PATCH 3/5] Return partial KNN results --- .../lucene/search/AbstractKnnVectorQuery.java | 22 +++++++++---------- .../TimeLimitingKnnCollectorManager.java | 14 +++++++----- ...iversifyingChildrenByteKnnVectorQuery.java | 7 ++++-- ...versifyingChildrenFloatKnnVectorQuery.java | 7 ++++-- 4 files changed, 30 insertions(+), 20 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index 16e88d858550..1dec9da30428 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -134,27 +134,24 @@ private TopDocs getLeafResults( BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc); final int cost = acceptDocs.cardinality(); + QueryTimeout queryTimeout = timeLimitingKnnCollectorManager.getQueryTimeout(); if (cost <= k) { // If there are <= k possible matches, short-circuit and perform exact search, since HNSW // must always visit at least k documents - return exactSearch( - ctx, - new BitSetIterator(acceptDocs, cost), - timeLimitingKnnCollectorManager.getQueryTimeout()); + return exactSearch(ctx, new BitSetIterator(acceptDocs, cost), queryTimeout); } // Perform the approximate kNN search // We pass cost + 1 here to account for the edge case when we explore exactly cost vectors TopDocs results = approximateSearch(ctx, acceptDocs, cost + 1, timeLimitingKnnCollectorManager); - if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) { + if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO + // Return partial results only when timeout is met + || (queryTimeout != null && queryTimeout.shouldExit())) { return results; } else { // We stopped the kNN search because it visited too many nodes, so fall back to exact search - return exactSearch( - ctx, - new BitSetIterator(acceptDocs, cost), - timeLimitingKnnCollectorManager.getQueryTimeout()); + return exactSearch(ctx, new BitSetIterator(acceptDocs, cost), queryTimeout); } } @@ -206,11 +203,14 @@ protected TopDocs exactSearch( } final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost())); HitQueue queue = new HitQueue(queueSize, true); + TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO; ScoreDoc topDoc = queue.top(); int doc; while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + // Mark results as partial if timeout is met if (queryTimeout != null && queryTimeout.shouldExit()) { - return NO_RESULTS; + relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + break; } boolean advanced = vectorScorer.advanceExact(doc); @@ -234,7 +234,7 @@ protected TopDocs exactSearch( topScoreDocs[i] = queue.pop(); } - TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO); + TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation); return new TopDocs(totalHits, topScoreDocs); } diff --git a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java index 20c0d6bb4dfe..a7711d4aad20 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java +++ b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java @@ -80,11 +80,15 @@ public float minCompetitiveSimilarity() { @Override public TopDocs topDocs() { - if (queryTimeout.shouldExit()) { - // Quickly exit - return TopDocsCollector.EMPTY_TOPDOCS; - } - return collector.topDocs(); + TopDocs docs = collector.topDocs(); + + // Mark results as partial if timeout is met + TotalHits.Relation relation = + queryTimeout.shouldExit() + ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + : docs.totalHits.relation; + + return new TopDocs(new TotalHits(docs.totalHits.value, relation), docs.scoreDocs); } }; } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java index 9d628d8d44f9..f42f456cf8cd 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java @@ -102,10 +102,13 @@ protected TopDocs exactSearch( fi.getVectorSimilarityFunction()); final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost())); HitQueue queue = new HitQueue(queueSize, true); + TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO; ScoreDoc topDoc = queue.top(); while (vectorScorer.nextParent() != DocIdSetIterator.NO_MORE_DOCS) { + // Mark results as partial if timeout is met if (queryTimeout != null && queryTimeout.shouldExit()) { - return NO_RESULTS; + relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + break; } float score = vectorScorer.score(); @@ -126,7 +129,7 @@ protected TopDocs exactSearch( topScoreDocs[i] = queue.pop(); } - TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO); + TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation); return new TopDocs(totalHits, topScoreDocs); } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java index ac1def0823a1..8153103f9505 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java @@ -102,10 +102,13 @@ protected TopDocs exactSearch( fi.getVectorSimilarityFunction()); final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost())); HitQueue queue = new HitQueue(queueSize, true); + TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO; ScoreDoc topDoc = queue.top(); while (vectorScorer.nextParent() != DocIdSetIterator.NO_MORE_DOCS) { + // Mark results as partial if timeout is met if (queryTimeout != null && queryTimeout.shouldExit()) { - return NO_RESULTS; + relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + break; } float score = vectorScorer.score(); @@ -126,7 +129,7 @@ protected TopDocs exactSearch( topScoreDocs[i] = queue.pop(); } - TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO); + TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation); return new TopDocs(totalHits, topScoreDocs); } From 6908af479fd93e858e83dc58b2a83292ef9568fa Mon Sep 17 00:00:00 2001 From: Kaival Parikh Date: Tue, 26 Mar 2024 21:27:29 +0000 Subject: [PATCH 4/5] Add tests for partial KNN results - Refactor tests to base classes - Also timeout exact searches in Lucene99HnswVectorsReader --- .../lucene99/Lucene99HnswVectorsReader.java | 6 +++ .../search/BaseKnnVectorQueryTestCase.java | 46 ++++++++++++++++++ .../lucene/search/TestKnnByteVectorQuery.java | 19 -------- .../search/TestKnnFloatVectorQuery.java | 19 -------- ...ParentBlockJoinKnnVectorQueryTestCase.java | 48 +++++++++++++++++++ ...TestParentBlockJoinByteKnnVectorQuery.java | 24 ---------- ...estParentBlockJoinFloatKnnVectorQuery.java | 24 ---------- 7 files changed, 100 insertions(+), 86 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java index efb51c963e0f..ff15f4903ad6 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java @@ -251,6 +251,9 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits // and collect them for (int i = 0; i < scorer.maxOrd(); i++) { if (acceptedOrds == null || acceptedOrds.get(i)) { + if (knnCollector.earlyTerminated()) { + break; + } knnCollector.incVisitedCount(1); knnCollector.collect(scorer.ordToDoc(i), scorer.score(i)); } @@ -279,6 +282,9 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits // and collect them for (int i = 0; i < scorer.maxOrd(); i++) { if (acceptedOrds == null || acceptedOrds.get(i)) { + if (knnCollector.earlyTerminated()) { + break; + } knnCollector.incVisitedCount(1); knnCollector.collect(scorer.ordToDoc(i), scorer.score(i)); } diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index 5591a9059c9a..cc3a8c7ce725 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -38,6 +38,7 @@ import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; import org.apache.lucene.index.StoredFields; import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorSimilarityFunction; @@ -762,6 +763,34 @@ public void testBitSetQuery() throws IOException { } } + /** Test that the query times out correctly. */ + public void testTimeout() throws IOException { + try (Directory indexStore = + getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); + IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + + AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 2); + AbstractKnnVectorQuery exactQuery = + getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10, new MatchAllDocsQuery()); + + assertEquals(2, searcher.count(query)); // Expect some results without timeout + assertEquals(3, searcher.count(exactQuery)); // Same for exact search + + searcher.setTimeout(() -> true); // Immediately timeout + assertEquals(0, searcher.count(query)); // Expect no results with the timeout + assertEquals(0, searcher.count(exactQuery)); // Same for exact search + + searcher.setTimeout(new CountingQueryTimeout(1)); // Only score 1 doc + // Note: This depends on the HNSW graph having just one layer, + // would be 0 in case of multiple layers + assertEquals(1, searcher.count(query)); // Expect only 1 result + + searcher.setTimeout(new CountingQueryTimeout(1)); // Only score 1 doc + assertEquals(1, searcher.count(exactQuery)); // Expect only 1 result + } + } + /** Creates a new directory and adds documents with the given vectors as kNN vector fields */ Directory getIndexStore(String field, float[]... contents) throws IOException { return getIndexStore(field, VectorSimilarityFunction.EUCLIDEAN, contents); @@ -949,4 +978,21 @@ public int hashCode() { return 31 * classHash() + docs.hashCode(); } } + + private static class CountingQueryTimeout implements QueryTimeout { + private int remaining; + + public CountingQueryTimeout(int count) { + remaining = count; + } + + @Override + public boolean shouldExit() { + if (remaining > 0) { + remaining--; + return false; + } + return true; + } + } } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java index 676f96857e3a..4dc3d385b087 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java @@ -103,25 +103,6 @@ public void testVectorEncodingMismatch() throws IOException { } } - public void testTimeout() throws IOException { - try (Directory indexStore = - getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); - IndexReader reader = DirectoryReader.open(indexStore)) { - IndexSearcher searcher = newSearcher(reader); - - AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 2); - AbstractKnnVectorQuery exactQuery = - getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10, new MatchAllDocsQuery()); - - assertEquals(2, searcher.count(query)); // Expect some results without timeout - assertEquals(3, searcher.count(exactQuery)); // Same for exact search - - searcher.setTimeout(() -> true); // Immediately timeout - assertEquals(0, searcher.count(query)); // Expect no results with the timeout - assertEquals(0, searcher.count(exactQuery)); // Same for exact search - } - } - private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery { public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) { diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java index d1ca3f7d7410..6ecc758c0e49 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java @@ -252,25 +252,6 @@ public void testDocAndScoreQueryBasics() throws IOException { } } - public void testTimeout() throws IOException { - try (Directory indexStore = - getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); - IndexReader reader = DirectoryReader.open(indexStore)) { - IndexSearcher searcher = newSearcher(reader); - - AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 2); - AbstractKnnVectorQuery exactQuery = - getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10, new MatchAllDocsQuery()); - - assertEquals(2, searcher.count(query)); // Expect some results without timeout - assertEquals(3, searcher.count(exactQuery)); // Same for exact search - - searcher.setTimeout(() -> true); // Immediately timeout - assertEquals(0, searcher.count(query)); // Expect no results with the timeout - assertEquals(0, searcher.count(exactQuery)); // Same for exact search - } - } - private static class ThrowingKnnVectorQuery extends KnnFloatVectorQuery { public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter) { diff --git a/lucene/join/src/test/org/apache/lucene/search/join/ParentBlockJoinKnnVectorQueryTestCase.java b/lucene/join/src/test/org/apache/lucene/search/join/ParentBlockJoinKnnVectorQueryTestCase.java index b7f3f5f9787a..0ca5e7a65869 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/ParentBlockJoinKnnVectorQueryTestCase.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/ParentBlockJoinKnnVectorQueryTestCase.java @@ -33,6 +33,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.QueryTimeout; import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; @@ -283,6 +284,36 @@ public void testSkewedIndex() throws IOException { } } + /** Test that the query times out correctly. */ + public void testTimeout() throws IOException { + try (Directory indexStore = + getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); + IndexReader reader = DirectoryReader.open(indexStore)) { + BitSetProducer parentFilter = parentFilter(reader); + IndexSearcher searcher = newSearcher(reader); + + Query query = getParentJoinKnnQuery("field", new float[] {1, 2}, null, 2, parentFilter); + Query exactQuery = + getParentJoinKnnQuery( + "field", new float[] {1, 2}, new MatchAllDocsQuery(), 10, parentFilter); + + assertEquals(2, searcher.count(query)); // Expect some results without timeout + assertEquals(3, searcher.count(exactQuery)); // Same for exact search + + searcher.setTimeout(() -> true); // Immediately timeout + assertEquals(0, searcher.count(query)); // Expect no results with the timeout + assertEquals(0, searcher.count(exactQuery)); // Same for exact search + + searcher.setTimeout(new CountingQueryTimeout(1)); // Only score 1 parent + // Note: This depends on the HNSW graph having just one layer, + // would be 0 in case of multiple layers + assertEquals(1, searcher.count(query)); // Expect only 1 result + + searcher.setTimeout(new CountingQueryTimeout(1)); // Only score 1 parent + assertEquals(1, searcher.count(exactQuery)); // Expect only 1 result + } + } + Directory getIndexStore(String field, float[]... contents) throws IOException { Directory indexStore = newDirectory(); RandomIndexWriter writer = @@ -352,4 +383,21 @@ void assertScorerResults( assertEquals(idToScore.get(actualId), scorer.score(), 0.0001); } } + + private static class CountingQueryTimeout implements QueryTimeout { + private int remaining; + + public CountingQueryTimeout(int count) { + remaining = count; + } + + @Override + public boolean shouldExit() { + if (remaining > 0) { + remaining--; + return false; + } + return true; + } + } } diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java index 91f807440aa3..5b70763d044b 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java @@ -25,7 +25,6 @@ import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TermQuery; import org.apache.lucene.store.Directory; @@ -68,29 +67,6 @@ public void testVectorEncodingMismatch() throws IOException { } } - public void testTimeout() throws IOException { - try (Directory indexStore = - getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); - IndexReader reader = DirectoryReader.open(indexStore)) { - BitSetProducer parentFilter = parentFilter(reader); - IndexSearcher searcher = newSearcher(reader); - - Query query = - new DiversifyingChildrenByteKnnVectorQuery( - "field", new byte[] {1, 2}, null, 2, parentFilter); - Query exactQuery = - new DiversifyingChildrenByteKnnVectorQuery( - "field", new byte[] {1, 2}, new MatchAllDocsQuery(), 10, parentFilter); - - assertEquals(2, searcher.count(query)); // Expect some results without timeout - assertEquals(3, searcher.count(exactQuery)); // Same for exact search - - searcher.setTimeout(() -> true); // Immediately timeout - assertEquals(0, searcher.count(query)); // Expect no results with the timeout - assertEquals(0, searcher.count(exactQuery)); // Same for exact search - } - } - private static byte[] fromFloat(float[] queryVector) { byte[] query = new byte[queryVector.length]; for (int i = 0; i < queryVector.length; i++) { diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java index e8f246a53f98..fee519a1e1d6 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java @@ -32,7 +32,6 @@ import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TermQuery; import org.apache.lucene.store.Directory; @@ -104,29 +103,6 @@ d, new IndexWriterConfig().setMergePolicy(newMergePolicy(random(), false)))) { } } - public void testTimeout() throws IOException { - try (Directory indexStore = - getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); - IndexReader reader = DirectoryReader.open(indexStore)) { - BitSetProducer parentFilter = parentFilter(reader); - IndexSearcher searcher = newSearcher(reader); - - Query query = - new DiversifyingChildrenFloatKnnVectorQuery( - "field", new float[] {1, 2}, null, 2, parentFilter); - Query exactQuery = - new DiversifyingChildrenFloatKnnVectorQuery( - "field", new float[] {1, 2}, new MatchAllDocsQuery(), 10, parentFilter); - - assertEquals(2, searcher.count(query)); // Expect some results without timeout - assertEquals(3, searcher.count(exactQuery)); // Same for exact search - - searcher.setTimeout(() -> true); // Immediately timeout - assertEquals(0, searcher.count(query)); // Expect no results with the timeout - assertEquals(0, searcher.count(exactQuery)); // Same for exact search - } - } - @Override Field getKnnVectorField(String name, float[] vector) { return new KnnFloatVectorField(name, vector); From 8ad040ab75a073e71f4299040b2c613b78fc1e14 Mon Sep 17 00:00:00 2001 From: Kaival Parikh Date: Tue, 2 Apr 2024 09:00:54 +0000 Subject: [PATCH 5/5] Add CHANGES.txt entry and fix some comments --- lucene/CHANGES.txt | 3 +++ .../src/java/org/apache/lucene/search/IndexSearcher.java | 5 ++++- .../lucene/search/TimeLimitingKnnCollectorManager.java | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index dcce957ff542..3d8b797fdd88 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -236,6 +236,9 @@ Improvements implementation is the ConcurrentMergeScheduler and the Lucene99HnswVectorsFormat will use it if no other executor is provided. (Ben Trent) +* GITHUB#13202: Early terminate graph and exact searches of AbstractKnnVectorQuery to follow timeout set from + IndexSearcher#setTimeout(QueryTimeout). (Kaival Parikh) + Optimizations --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java index bb937355b4a7..3c2a79e196a6 100644 --- a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java @@ -483,7 +483,10 @@ public TopDocs searchAfter(ScoreDoc after, Query query, int numHits) throws IOEx return search(query, manager); } - /** Get a {@link QueryTimeout} for all searches that run through this {@link IndexSearcher}. */ + /** + * Get the configured {@link QueryTimeout} for all searches that run through this {@link + * IndexSearcher}, or {@code null} if not set. + */ public QueryTimeout getTimeout() { return this.queryTimeout; } diff --git a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java index a7711d4aad20..c92fe0f9e34a 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java +++ b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java @@ -31,7 +31,7 @@ public TimeLimitingKnnCollectorManager(KnnCollectorManager delegate, QueryTimeou this.queryTimeout = timeout; } - /** Get the {@link QueryTimeout} for terminating graph searches. */ + /** Get the configured {@link QueryTimeout} for terminating graph and exact searches. */ public QueryTimeout getQueryTimeout() { return queryTimeout; }