diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/composite/CompositeAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/composite/CompositeAggregator.java index 9612ba2f895bc..3467aaf318baf 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/composite/CompositeAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/composite/CompositeAggregator.java @@ -23,6 +23,9 @@ import org.apache.lucene.search.CollectionTerminatedException; import org.apache.lucene.search.DocIdSet; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; import org.apache.lucene.util.RoaringDocIdSet; import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.AggregatorFactories; @@ -87,6 +90,12 @@ public InternalAggregation buildAggregation(long zeroBucket) throws IOException // Replay all documents that contain at least one top bucket (collected during the first pass). grow(keys.size()+1); + final boolean needsScores = needsScores(); + Weight weight = null; + if (needsScores) { + Query query = context.query(); + weight = context.searcher().createNormalizedWeight(query, true); + } for (LeafContext context : contexts) { DocIdSetIterator docIdSetIterator = context.docIdSet.iterator(); if (docIdSetIterator == null) { @@ -95,7 +104,21 @@ public InternalAggregation buildAggregation(long zeroBucket) throws IOException final CompositeValuesSource.Collector collector = array.getLeafCollector(context.ctx, getSecondPassCollector(context.subCollector)); int docID; + DocIdSetIterator scorerIt = null; + if (needsScores) { + Scorer scorer = weight.scorer(context.ctx); + // We don't need to check if the scorer is null + // since we are sure that there are documents to replay (docIdSetIterator it not empty). + scorerIt = scorer.iterator(); + context.subCollector.setScorer(scorer); + } while ((docID = docIdSetIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + if (needsScores) { + assert scorerIt.docID() < docID; + scorerIt.advance(docID); + // aggregations should only be replayed on matching documents + assert scorerIt.docID() == docID; + } collector.collect(docID); } } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/composite/CompositeAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/composite/CompositeAggregatorTests.java index 339f9bda65a0a..172aebbc0e5dc 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/composite/CompositeAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/composite/CompositeAggregatorTests.java @@ -50,6 +50,8 @@ import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.search.aggregations.AggregatorTestCase; import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramInterval; +import org.elasticsearch.search.aggregations.metrics.tophits.TopHits; +import org.elasticsearch.search.aggregations.metrics.tophits.TopHitsAggregationBuilder; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.test.IndexSettingsModule; import org.joda.time.DateTimeZone; @@ -1065,8 +1067,73 @@ public void testWithKeywordAndDateHistogram() throws IOException { ); } - private void testSearchCase(Query query, - Sort sort, + public void testWithKeywordAndTopHits() throws Exception { + final List>> dataset = new ArrayList<>(); + dataset.addAll( + Arrays.asList( + createDocument("keyword", "a"), + createDocument("keyword", "c"), + createDocument("keyword", "a"), + createDocument("keyword", "d"), + createDocument("keyword", "c") + ) + ); + final Sort sort = new Sort(new SortedSetSortField("keyword", false)); + testSearchCase(new MatchAllDocsQuery(), sort, dataset, + () -> { + TermsValuesSourceBuilder terms = new TermsValuesSourceBuilder("keyword") + .field("keyword"); + return new CompositeAggregationBuilder("name", Collections.singletonList(terms)) + .subAggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_")); + }, (result) -> { + assertEquals(3, result.getBuckets().size()); + assertEquals("{keyword=a}", result.getBuckets().get(0).getKeyAsString()); + assertEquals(2L, result.getBuckets().get(0).getDocCount()); + TopHits topHits = result.getBuckets().get(0).getAggregations().get("top_hits"); + assertNotNull(topHits); + assertEquals(topHits.getHits().getHits().length, 2); + assertEquals(topHits.getHits().getTotalHits(), 2L); + assertEquals("{keyword=c}", result.getBuckets().get(1).getKeyAsString()); + assertEquals(2L, result.getBuckets().get(1).getDocCount()); + topHits = result.getBuckets().get(1).getAggregations().get("top_hits"); + assertNotNull(topHits); + assertEquals(topHits.getHits().getHits().length, 2); + assertEquals(topHits.getHits().getTotalHits(), 2L); + assertEquals("{keyword=d}", result.getBuckets().get(2).getKeyAsString()); + assertEquals(1L, result.getBuckets().get(2).getDocCount()); + topHits = result.getBuckets().get(2).getAggregations().get("top_hits"); + assertNotNull(topHits); + assertEquals(topHits.getHits().getHits().length, 1); + assertEquals(topHits.getHits().getTotalHits(), 1L);; + } + ); + + testSearchCase(new MatchAllDocsQuery(), sort, dataset, + () -> { + TermsValuesSourceBuilder terms = new TermsValuesSourceBuilder("keyword") + .field("keyword"); + return new CompositeAggregationBuilder("name", Collections.singletonList(terms)) + .aggregateAfter(Collections.singletonMap("keyword", "a")) + .subAggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_")); + }, (result) -> { + assertEquals(2, result.getBuckets().size()); + assertEquals("{keyword=c}", result.getBuckets().get(0).getKeyAsString()); + assertEquals(2L, result.getBuckets().get(0).getDocCount()); + TopHits topHits = result.getBuckets().get(0).getAggregations().get("top_hits"); + assertNotNull(topHits); + assertEquals(topHits.getHits().getHits().length, 2); + assertEquals(topHits.getHits().getTotalHits(), 2L); + assertEquals("{keyword=d}", result.getBuckets().get(1).getKeyAsString()); + assertEquals(1L, result.getBuckets().get(1).getDocCount()); + topHits = result.getBuckets().get(1).getAggregations().get("top_hits"); + assertNotNull(topHits); + assertEquals(topHits.getHits().getHits().length, 1); + assertEquals(topHits.getHits().getTotalHits(), 1L); + } + ); + } + + private void testSearchCase(Query query, Sort sort, List>> dataset, Supplier create, Consumer verify) throws IOException { @@ -1107,7 +1174,7 @@ private void executeTestCase(boolean reduced, IndexSearcher indexSearcher = newSearcher(indexReader, sort == null, sort == null); CompositeAggregationBuilder aggregationBuilder = create.get(); if (sort != null) { - CompositeAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, indexSettings, FIELD_TYPES); + CompositeAggregator aggregator = createAggregator(query, aggregationBuilder, indexSearcher, indexSettings, FIELD_TYPES); assertTrue(aggregator.canEarlyTerminate()); } final InternalComposite composite; diff --git a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java index 137865beb1540..05ca2d40f82ca 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java @@ -103,12 +103,22 @@ protected AggregatorFactory createAggregatorFactory(AggregationBuilder aggreg new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes); } - /** Create a factory for the given aggregation builder. */ + protected AggregatorFactory createAggregatorFactory(AggregationBuilder aggregationBuilder, IndexSearcher indexSearcher, IndexSettings indexSettings, MultiBucketConsumer bucketConsumer, MappedFieldType... fieldTypes) throws IOException { + return createAggregatorFactory(null, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes); + } + + /** Create a factory for the given aggregation builder. */ + protected AggregatorFactory createAggregatorFactory(Query query, + AggregationBuilder aggregationBuilder, + IndexSearcher indexSearcher, + IndexSettings indexSettings, + MultiBucketConsumer bucketConsumer, + MappedFieldType... fieldTypes) throws IOException { SearchContext searchContext = createSearchContext(indexSearcher, indexSettings); CircuitBreakerService circuitBreakerService = new NoneCircuitBreakerService(); when(searchContext.aggregations()) @@ -116,6 +126,7 @@ protected AggregatorFactory createAggregatorFactory(AggregationBuilder aggreg when(searchContext.bigArrays()).thenReturn( new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), circuitBreakerService) ); + when(searchContext.query()).thenReturn(query); // TODO: now just needed for top_hits, this will need to be revised for other agg unit tests: MapperService mapperService = mapperServiceMock(); when(mapperService.getIndexSettings()).thenReturn(indexSettings); @@ -148,19 +159,20 @@ protected A createAggregator(AggregationBuilder aggregati new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes); } - protected A createAggregator(AggregationBuilder aggregationBuilder, + protected A createAggregator(Query query, + AggregationBuilder aggregationBuilder, IndexSearcher indexSearcher, IndexSettings indexSettings, MappedFieldType... fieldTypes) throws IOException { - return createAggregator(aggregationBuilder, indexSearcher, indexSettings, + return createAggregator(query, aggregationBuilder, indexSearcher, indexSettings, new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes); } - protected A createAggregator(AggregationBuilder aggregationBuilder, + protected A createAggregator(Query query, AggregationBuilder aggregationBuilder, IndexSearcher indexSearcher, MultiBucketConsumer bucketConsumer, MappedFieldType... fieldTypes) throws IOException { - return createAggregator(aggregationBuilder, indexSearcher, createIndexSettings(), bucketConsumer, fieldTypes); + return createAggregator(query, aggregationBuilder, indexSearcher, createIndexSettings(), bucketConsumer, fieldTypes); } protected A createAggregator(AggregationBuilder aggregationBuilder, @@ -168,8 +180,17 @@ protected A createAggregator(AggregationBuilder aggregati IndexSettings indexSettings, MultiBucketConsumer bucketConsumer, MappedFieldType... fieldTypes) throws IOException { + return createAggregator(null, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes); + } + + protected A createAggregator(Query query, + AggregationBuilder aggregationBuilder, + IndexSearcher indexSearcher, + IndexSettings indexSettings, + MultiBucketConsumer bucketConsumer, + MappedFieldType... fieldTypes) throws IOException { @SuppressWarnings("unchecked") - A aggregator = (A) createAggregatorFactory(aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes) + A aggregator = (A) createAggregatorFactory(query, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes) .create(null, true); return aggregator; } @@ -264,7 +285,7 @@ protected A search(IndexSe int maxBucket, MappedFieldType... fieldTypes) throws IOException { MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(maxBucket); - C a = createAggregator(builder, searcher, bucketConsumer, fieldTypes); + C a = createAggregator(query, builder, searcher, bucketConsumer, fieldTypes); a.preCollection(); searcher.search(query, a); a.postCollection(); @@ -312,11 +333,11 @@ protected A searchAndReduc Query rewritten = searcher.rewrite(query); Weight weight = searcher.createWeight(rewritten, true, 1f); MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(maxBucket); - C root = createAggregator(builder, searcher, bucketConsumer, fieldTypes); + C root = createAggregator(query, builder, searcher, bucketConsumer, fieldTypes); for (ShardSearcher subSearcher : subSearchers) { MultiBucketConsumer shardBucketConsumer = new MultiBucketConsumer(maxBucket); - C a = createAggregator(builder, subSearcher, shardBucketConsumer, fieldTypes); + C a = createAggregator(query, builder, subSearcher, shardBucketConsumer, fieldTypes); a.preCollection(); subSearcher.search(weight, a); a.postCollection();