diff --git a/server/src/main/java/org/elasticsearch/search/query/TopDocsCollectorContext.java b/server/src/main/java/org/elasticsearch/search/query/TopDocsCollectorContext.java index 5ae6cc739c362..751c1cd8bfbe7 100644 --- a/server/src/main/java/org/elasticsearch/search/query/TopDocsCollectorContext.java +++ b/server/src/main/java/org/elasticsearch/search/query/TopDocsCollectorContext.java @@ -33,13 +33,11 @@ import org.apache.lucene.search.ConstantScoreQuery; import org.apache.lucene.search.DocValuesFieldExistsQuery; import org.apache.lucene.search.FieldDoc; -import org.apache.lucene.search.FilterCollector; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MultiCollector; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.apache.lucene.search.TermQuery; @@ -238,7 +236,15 @@ private SimpleTopDocsCollectorContext(IndexReader reader, this.sortAndFormats = sortAndFormats; final TopDocsCollector topDocsCollector; - if (trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED) { + + if ((sortAndFormats == null || SortField.FIELD_SCORE.equals(sortAndFormats.sort.getSort()[0])) + && hasInfMaxScore(query)) { + // disable max score optimization since we have a mandatory clause + // that doesn't track the maximum score + topDocsCollector = createCollector(sortAndFormats, numHits, searchAfter, Integer.MAX_VALUE); + topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs); + totalHitsSupplier = () -> topDocsSupplier.get().totalHits; + } else if (trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED) { // don't compute hit counts via the collector topDocsCollector = createCollector(sortAndFormats, numHits, searchAfter, 1); topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs); @@ -274,27 +280,7 @@ private SimpleTopDocsCollectorContext(IndexReader reader, maxScoreSupplier = () -> Float.NaN; } - final Collector collector = MultiCollector.wrap(topDocsCollector, maxScoreCollector); - if (sortAndFormats == null || - SortField.FIELD_SCORE.equals(sortAndFormats.sort.getSort()[0])) { - if (hasInfMaxScore(query)) { - // disable max score optimization since we have a mandatory clause - // that doesn't track the maximum score - this.collector = new FilterCollector(collector) { - @Override - public ScoreMode scoreMode() { - if (in.scoreMode() == ScoreMode.TOP_SCORES) { - return ScoreMode.COMPLETE; - } - return in.scoreMode(); - } - }; - } else { - this.collector = collector; - } - } else { - this.collector = collector; - } + this.collector = MultiCollector.wrap(topDocsCollector, maxScoreCollector); } diff --git a/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java b/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java index 192e54506495f..2190e573707e6 100644 --- a/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java @@ -606,15 +606,16 @@ public void testDisableTopScoreCollection() throws Exception { .build(); context.parsedQuery(new ParsedQuery(q)); - context.setSize(10); + context.setSize(3); + context.trackTotalHitsUpTo(3); + TopDocsCollectorContext topDocsContext = TopDocsCollectorContext.createTopDocsCollectorContext(context, reader, false); assertEquals(topDocsContext.create(null).scoreMode(), org.apache.lucene.search.ScoreMode.COMPLETE); QueryPhase.execute(context, contextSearcher, checkCancelled -> {}); assertEquals(5, context.queryResult().topDocs().topDocs.totalHits.value); assertEquals(context.queryResult().topDocs().topDocs.totalHits.relation, TotalHits.Relation.EQUAL_TO); - assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(5)); - + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(3)); context.sort(new SortAndFormats(new Sort(new SortField("other", SortField.Type.INT)), new DocValueFormat[] { DocValueFormat.RAW })); @@ -623,7 +624,7 @@ public void testDisableTopScoreCollection() throws Exception { assertEquals(topDocsContext.create(null).scoreMode(), org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES); QueryPhase.execute(context, contextSearcher, checkCancelled -> {}); assertEquals(5, context.queryResult().topDocs().topDocs.totalHits.value); - assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(5)); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(3)); assertEquals(context.queryResult().topDocs().topDocs.totalHits.relation, TotalHits.Relation.EQUAL_TO); reader.close();