Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a post-collection hook to LeafCollector. #12380

Merged
merged 8 commits into from
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@ API Changes

New Features
---------------------
(No changes)

* GITHUB#12383: Introduced LeafCollector#finish, a hook that runs after
collection has finished running on a leaf. (Adrien Grand)

Improvements
---------------------
Expand Down
71 changes: 35 additions & 36 deletions lucene/core/src/java/org/apache/lucene/search/CachingCollector.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ private static class NoScoreCachingCollector extends CachingCollector {
List<LeafReaderContext> contexts;
List<int[]> docs;
int maxDocsToCache;
NoScoreCachingLeafCollector lastCollector;

NoScoreCachingCollector(Collector in, int maxDocsToCache) {
super(in);
Expand All @@ -76,7 +75,7 @@ private static class NoScoreCachingCollector extends CachingCollector {
}

protected NoScoreCachingLeafCollector wrap(LeafCollector in, int maxDocsToCache) {
return new NoScoreCachingLeafCollector(in, maxDocsToCache);
return new NoScoreCachingLeafCollector(in, maxDocsToCache, this);
}

// note: do *not* override needScore to say false. Just because we aren't caching the score
Expand All @@ -85,13 +84,12 @@ protected NoScoreCachingLeafCollector wrap(LeafCollector in, int maxDocsToCache)

@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
postCollection();
final LeafCollector in = this.in.getLeafCollector(context);
if (contexts != null) {
contexts.add(context);
}
if (maxDocsToCache >= 0) {
return lastCollector = wrap(in, maxDocsToCache);
if (contexts != null) {
contexts.add(context);
}
return wrap(in, maxDocsToCache);
} else {
return in;
}
Expand All @@ -103,33 +101,16 @@ protected void invalidate() {
this.docs = null;
}

protected void postCollect(NoScoreCachingLeafCollector collector) {
final int[] docs = collector.cachedDocs();
maxDocsToCache -= docs.length;
this.docs.add(docs);
}

private void postCollection() {
if (lastCollector != null) {
if (!lastCollector.hasCache()) {
invalidate();
} else {
postCollect(lastCollector);
}
lastCollector = null;
}
}

protected void collect(LeafCollector collector, int i) throws IOException {
final int[] docs = this.docs.get(i);
for (int doc : docs) {
collector.collect(doc);
}
collector.finish();
}

@Override
public void replay(Collector other) throws IOException {
postCollection();
if (!isCached()) {
throw new IllegalStateException(
"cannot replay: cache was cleared because too much RAM was required");
Expand All @@ -154,14 +135,7 @@ private static class ScoreCachingCollector extends NoScoreCachingCollector {

@Override
protected NoScoreCachingLeafCollector wrap(LeafCollector in, int maxDocsToCache) {
return new ScoreCachingLeafCollector(in, maxDocsToCache);
}

@Override
protected void postCollect(NoScoreCachingLeafCollector collector) {
final ScoreCachingLeafCollector coll = (ScoreCachingLeafCollector) collector;
super.postCollect(coll);
scores.add(coll.cachedScores());
return new ScoreCachingLeafCollector(in, maxDocsToCache, this);
}

/**
Expand Down Expand Up @@ -191,12 +165,15 @@ protected void collect(LeafCollector collector, int i) throws IOException {
private class NoScoreCachingLeafCollector extends FilterLeafCollector {

final int maxDocsToCache;
final NoScoreCachingCollector collector;
int[] docs;
int docCount;

NoScoreCachingLeafCollector(LeafCollector in, int maxDocsToCache) {
NoScoreCachingLeafCollector(
LeafCollector in, int maxDocsToCache, NoScoreCachingCollector collector) {
super(in);
this.maxDocsToCache = maxDocsToCache;
this.collector = collector;
docs = new int[Math.min(maxDocsToCache, INITIAL_ARRAY_SIZE)];
docCount = 0;
}
Expand Down Expand Up @@ -235,6 +212,21 @@ public void collect(int doc) throws IOException {
super.collect(doc);
}

protected void postCollect() {
final int[] docs = cachedDocs();
collector.maxDocsToCache -= docs.length;
collector.docs.add(docs);
}

@Override
public void finish() {
if (!hasCache()) {
collector.invalidate();
} else {
postCollect();
}
}

boolean hasCache() {
return docs != null;
}
Expand All @@ -249,8 +241,9 @@ private class ScoreCachingLeafCollector extends NoScoreCachingLeafCollector {
Scorable scorer;
float[] scores;

ScoreCachingLeafCollector(LeafCollector in, int maxDocsToCache) {
super(in, maxDocsToCache);
ScoreCachingLeafCollector(
LeafCollector in, int maxDocsToCache, ScoreCachingCollector collector) {
super(in, maxDocsToCache, collector);
scores = new float[docs.length];
}

Expand Down Expand Up @@ -281,6 +274,12 @@ protected void buffer(int doc) throws IOException {
float[] cachedScores() {
return docs == null ? null : ArrayUtil.copyOfSubArray(scores, 0, docCount);
}

@Override
protected void postCollect() {
super.postCollect();
((ScoreCachingCollector) collector).scores.add(cachedScores());
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ public void collect(int doc) throws IOException {
in.collect(doc);
}

@Override
public void finish() throws IOException {
in.finish();
}

@Override
public String toString() {
String name = getClass().getSimpleName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,9 @@ protected void search(List<LeafReaderContext> leaves, Weight weight, Collector c
partialResult = true;
}
}
// Note: this is called if collection ran successfully, including the above special cases of
zhaih marked this conversation as resolved.
Show resolved Hide resolved
// CollectionTerminatedException and TimeExceededException, but no other exception.
leafCollector.finish();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether it worths passing in the exceptions if any in case of early termination, but I can't think of a concrete example of how it might be useful right now (maybe user want a faster finish step in case of early terminated by time?), maybe we can add it later if there's a real need?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't think of a use-case either. Another argument could be that CollectionTerminatedException is only one way to skip hits, LeafCollector#competitiveIterator and Scorer#setMinCompetitiveScore are other ones, why would we give more information to finish() for one way of skipping and not for other ones?

Copy link
Contributor

@iverase iverase Jun 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing I notice in the case there is no doc of interest, it won't be called (see continue statement), I wonder if we should call it even in that case? we are not building a leaf collector in that case, sorry.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be an opportunity for capturing statistics about how often time-limitation is applied?

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,13 @@ public interface LeafCollector {
default DocIdSetIterator competitiveIterator() throws IOException {
return null;
}

/**
* Hook that gets called once the leaf that is associated with this collector has finished
* collecting successfully, including when a {@link CollectionTerminatedException} is thrown. This
* is typically useful to compile data that has been collected on this leaf, e.g. to convert facet
* counts on leaf ordinals to facet counts on global ordinals. The default implementation does
* nothing.
*/
default void finish() throws IOException {}
}
10 changes: 10 additions & 0 deletions lucene/core/src/java/org/apache/lucene/search/MultiCollector.java
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ public void collect(int doc) throws IOException {
} catch (
@SuppressWarnings("unused")
CollectionTerminatedException e) {
collectors[i].finish();
collectors[i] = null;
if (allCollectorsTerminated()) {
throw new CollectionTerminatedException();
Expand All @@ -232,6 +233,15 @@ public void collect(int doc) throws IOException {
}
}

@Override
public void finish() throws IOException {
for (LeafCollector collector : collectors) {
if (collector != null) {
collector.finish();
}
}
}

private boolean allCollectorsTerminated() {
for (int i = 0; i < collectors.length; i++) {
if (collectors[i] != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public void testBasic() throws Exception {
for (int i = 0; i < 1000; i++) {
acc.collect(i);
}
acc.finish();

// now replay them
cc.replay(
Expand Down Expand Up @@ -127,6 +128,7 @@ public void testNoWrappedCollector() throws Exception {
acc.collect(0);

assertTrue(cc.isCached());
acc.finish();
cc.replay(new NoOpCollector());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
Expand Down Expand Up @@ -198,6 +199,7 @@ private void doQueryFirstScoringSingleDim(

docID = baseApproximation.nextDoc();
}
finish(collector, Collections.singleton(dim));
}

/**
Expand Down Expand Up @@ -334,6 +336,8 @@ protected boolean lessThan(DocsAndCost a, DocsAndCost b) {

docID = baseApproximation.nextDoc();
}

finish(collector, sidewaysDims);
}

private static int advanceIfBehind(int docID, DocIdSetIterator iterator) throws IOException {
Expand Down Expand Up @@ -552,6 +556,7 @@ private void doDrillDownAdvanceScoring(

nextChunkStart += CHUNK;
}
finish(collector, Arrays.asList(dims));
}

private void doUnionScoring(Bits acceptDocs, LeafCollector collector, DocsAndCost[] dims)
Expand Down Expand Up @@ -706,6 +711,8 @@ private void doUnionScoring(Bits acceptDocs, LeafCollector collector, DocsAndCos

nextChunkStart += CHUNK;
}

finish(collector, Arrays.asList(dims));
}

private void collectHit(LeafCollector collector, DocsAndCost[] dims) throws IOException {
Expand Down Expand Up @@ -757,6 +764,16 @@ private void collectNearMiss(LeafCollector sidewaysCollector) throws IOException
sidewaysCollector.collect(collectDocID);
}

private void finish(LeafCollector collector, Collection<DocsAndCost> dims) throws IOException {
collector.finish();
if (drillDownLeafCollector != null) {
drillDownLeafCollector.finish();
}
for (DocsAndCost dim : dims) {
dim.sidewaysLeafCollector.finish();
}
}

private void setScorer(LeafCollector mainCollector, Scorable scorer) throws IOException {
mainCollector.setScorer(scorer);
if (drillDownLeafCollector != null) {
Expand Down
19 changes: 9 additions & 10 deletions lucene/facet/src/java/org/apache/lucene/facet/FacetsCollector.java
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,6 @@ public final boolean getKeepScores() {

/** Returns the documents matched by the query, one {@link MatchingDocs} per visited segment. */
public List<MatchingDocs> getMatchingDocs() {
if (docsBuilder != null) {
matchingDocs.add(new MatchingDocs(this.context, docsBuilder.build(), totalHits, scores));
docsBuilder = null;
scores = null;
context = null;
}

return matchingDocs;
}

Expand Down Expand Up @@ -139,9 +132,7 @@ public final void setScorer(Scorable scorer) throws IOException {

@Override
protected void doSetNextReader(LeafReaderContext context) throws IOException {
if (docsBuilder != null) {
matchingDocs.add(new MatchingDocs(this.context, docsBuilder.build(), totalHits, scores));
}
assert docsBuilder == null;
docsBuilder = new DocIdSetBuilder(context.reader().maxDoc());
totalHits = 0;
if (keepScores) {
Expand All @@ -150,6 +141,14 @@ protected void doSetNextReader(LeafReaderContext context) throws IOException {
this.context = context;
}

@Override
public void finish() throws IOException {
matchingDocs.add(new MatchingDocs(this.context, docsBuilder.build(), totalHits, scores));
docsBuilder = null;
scores = null;
context = null;
}

/** Utility method, to search and also collect all hits into the provided {@link Collector}. */
public static TopDocs search(IndexSearcher searcher, Query q, int n, Collector fc)
throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,6 @@ public TopGroups<?> getTopGroups(
// if (queueFull) {
// System.out.println("getTopGroups groupOffset=" + groupOffset + " topNGroups=" + topNGroups);
// }
if (subDocUpto != 0) {
processGroup();
}
if (groupOffset >= groupQueue.size()) {
return null;
}
Expand Down Expand Up @@ -472,9 +469,6 @@ public void collect(int doc) throws IOException {

@Override
protected void doSetNextReader(LeafReaderContext readerContext) throws IOException {
if (subDocUpto != 0) {
processGroup();
}
subDocUpto = 0;
docBase = readerContext.docBase;
// System.out.println("setNextReader base=" + docBase + " r=" + readerContext.reader);
Expand All @@ -492,6 +486,13 @@ protected void doSetNextReader(LeafReaderContext readerContext) throws IOExcepti
}
}

@Override
public void finish() throws IOException {
if (subDocUpto != 0) {
processGroup();
}
}

@Override
public ScoreMode scoreMode() {
return needsScores ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,6 @@ protected GroupFacetCollector(String groupField, String facetField, BytesRef fac
*/
public GroupedFacetResult mergeSegmentResults(int size, int minCount, boolean orderByCount)
throws IOException {
if (segmentFacetCounts != null) {
segmentResults.add(createSegmentResult());
segmentFacetCounts = null; // reset
}

int totalCount = 0;
int missingCount = 0;
SegmentResultPriorityQueue segments = new SegmentResultPriorityQueue(segmentResults.size());
Expand Down Expand Up @@ -109,6 +104,12 @@ public GroupedFacetResult mergeSegmentResults(int size, int minCount, boolean or
return facetResult;
}

@Override
public void finish() throws IOException {
segmentResults.add(createSegmentResult());
segmentFacetCounts = null;
}

protected abstract SegmentResult createSegmentResult() throws IOException;

@Override
Expand Down
Loading