Skip to content

Commit

Permalink
Fix NPE on composite aggregation with sub-aggregations that need scor…
Browse files Browse the repository at this point in the history
…es (#28129)

The composite aggregation defers the collection of sub-aggregations to a second pass that visits documents only if they
appear in the top buckets. Though the scorer for sub-aggregations is not set on this second pass and generates an NPE if any sub-aggregation
tries to access the score. This change creates a scorer for the second pass and makes sure that sub-aggs can use it safely to check the score of
the collected documents.
  • Loading branch information
jimczi authored Jan 15, 2018
1 parent ee7eac8 commit bd11e6c
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.DocIdSet;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.RoaringDocIdSet;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
Expand Down Expand Up @@ -87,6 +90,12 @@ public InternalAggregation buildAggregation(long zeroBucket) throws IOException

// Replay all documents that contain at least one top bucket (collected during the first pass).
grow(keys.size()+1);
final boolean needsScores = needsScores();
Weight weight = null;
if (needsScores) {
Query query = context.query();
weight = context.searcher().createNormalizedWeight(query, true);
}
for (LeafContext context : contexts) {
DocIdSetIterator docIdSetIterator = context.docIdSet.iterator();
if (docIdSetIterator == null) {
Expand All @@ -95,7 +104,21 @@ public InternalAggregation buildAggregation(long zeroBucket) throws IOException
final CompositeValuesSource.Collector collector =
array.getLeafCollector(context.ctx, getSecondPassCollector(context.subCollector));
int docID;
DocIdSetIterator scorerIt = null;
if (needsScores) {
Scorer scorer = weight.scorer(context.ctx);
// We don't need to check if the scorer is null
// since we are sure that there are documents to replay (docIdSetIterator it not empty).
scorerIt = scorer.iterator();
context.subCollector.setScorer(scorer);
}
while ((docID = docIdSetIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
if (needsScores) {
assert scorerIt.docID() < docID;
scorerIt.advance(docID);
// aggregations should only be replayed on matching documents
assert scorerIt.docID() == docID;
}
collector.collect(docID);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramInterval;
import org.elasticsearch.search.aggregations.metrics.tophits.TopHits;
import org.elasticsearch.search.aggregations.metrics.tophits.TopHitsAggregationBuilder;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.test.IndexSettingsModule;
import org.joda.time.DateTimeZone;
Expand Down Expand Up @@ -1065,8 +1067,73 @@ public void testWithKeywordAndDateHistogram() throws IOException {
);
}

private void testSearchCase(Query query,
Sort sort,
public void testWithKeywordAndTopHits() throws Exception {
final List<Map<String, List<Object>>> dataset = new ArrayList<>();
dataset.addAll(
Arrays.asList(
createDocument("keyword", "a"),
createDocument("keyword", "c"),
createDocument("keyword", "a"),
createDocument("keyword", "d"),
createDocument("keyword", "c")
)
);
final Sort sort = new Sort(new SortedSetSortField("keyword", false));
testSearchCase(new MatchAllDocsQuery(), sort, dataset,
() -> {
TermsValuesSourceBuilder terms = new TermsValuesSourceBuilder("keyword")
.field("keyword");
return new CompositeAggregationBuilder("name", Collections.singletonList(terms))
.subAggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"));
}, (result) -> {
assertEquals(3, result.getBuckets().size());
assertEquals("{keyword=a}", result.getBuckets().get(0).getKeyAsString());
assertEquals(2L, result.getBuckets().get(0).getDocCount());
TopHits topHits = result.getBuckets().get(0).getAggregations().get("top_hits");
assertNotNull(topHits);
assertEquals(topHits.getHits().getHits().length, 2);
assertEquals(topHits.getHits().getTotalHits(), 2L);
assertEquals("{keyword=c}", result.getBuckets().get(1).getKeyAsString());
assertEquals(2L, result.getBuckets().get(1).getDocCount());
topHits = result.getBuckets().get(1).getAggregations().get("top_hits");
assertNotNull(topHits);
assertEquals(topHits.getHits().getHits().length, 2);
assertEquals(topHits.getHits().getTotalHits(), 2L);
assertEquals("{keyword=d}", result.getBuckets().get(2).getKeyAsString());
assertEquals(1L, result.getBuckets().get(2).getDocCount());
topHits = result.getBuckets().get(2).getAggregations().get("top_hits");
assertNotNull(topHits);
assertEquals(topHits.getHits().getHits().length, 1);
assertEquals(topHits.getHits().getTotalHits(), 1L);;
}
);

testSearchCase(new MatchAllDocsQuery(), sort, dataset,
() -> {
TermsValuesSourceBuilder terms = new TermsValuesSourceBuilder("keyword")
.field("keyword");
return new CompositeAggregationBuilder("name", Collections.singletonList(terms))
.aggregateAfter(Collections.singletonMap("keyword", "a"))
.subAggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"));
}, (result) -> {
assertEquals(2, result.getBuckets().size());
assertEquals("{keyword=c}", result.getBuckets().get(0).getKeyAsString());
assertEquals(2L, result.getBuckets().get(0).getDocCount());
TopHits topHits = result.getBuckets().get(0).getAggregations().get("top_hits");
assertNotNull(topHits);
assertEquals(topHits.getHits().getHits().length, 2);
assertEquals(topHits.getHits().getTotalHits(), 2L);
assertEquals("{keyword=d}", result.getBuckets().get(1).getKeyAsString());
assertEquals(1L, result.getBuckets().get(1).getDocCount());
topHits = result.getBuckets().get(1).getAggregations().get("top_hits");
assertNotNull(topHits);
assertEquals(topHits.getHits().getHits().length, 1);
assertEquals(topHits.getHits().getTotalHits(), 1L);
}
);
}

private void testSearchCase(Query query, Sort sort,
List<Map<String, List<Object>>> dataset,
Supplier<CompositeAggregationBuilder> create,
Consumer<InternalComposite> verify) throws IOException {
Expand Down Expand Up @@ -1107,7 +1174,7 @@ private void executeTestCase(boolean reduced,
IndexSearcher indexSearcher = newSearcher(indexReader, sort == null, sort == null);
CompositeAggregationBuilder aggregationBuilder = create.get();
if (sort != null) {
CompositeAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, indexSettings, FIELD_TYPES);
CompositeAggregator aggregator = createAggregator(query, aggregationBuilder, indexSearcher, indexSettings, FIELD_TYPES);
assertTrue(aggregator.canEarlyTerminate());
}
final InternalComposite composite;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,27 @@ protected AggregatorFactory<?> createAggregatorFactory(AggregationBuilder aggreg
new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
}

/** Create a factory for the given aggregation builder. */

protected AggregatorFactory<?> createAggregatorFactory(AggregationBuilder aggregationBuilder,
IndexSearcher indexSearcher,
IndexSettings indexSettings,
MultiBucketConsumer bucketConsumer,
MappedFieldType... fieldTypes) throws IOException {
return createAggregatorFactory(null, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes);
}

/** Create a factory for the given aggregation builder. */
protected AggregatorFactory<?> createAggregatorFactory(Query query,
AggregationBuilder aggregationBuilder,
IndexSearcher indexSearcher,
IndexSettings indexSettings,
MultiBucketConsumer bucketConsumer,
MappedFieldType... fieldTypes) throws IOException {
SearchContext searchContext = createSearchContext(indexSearcher, indexSettings);
CircuitBreakerService circuitBreakerService = new NoneCircuitBreakerService();
when(searchContext.aggregations())
.thenReturn(new SearchContextAggregations(AggregatorFactories.EMPTY, bucketConsumer));
when(searchContext.query()).thenReturn(query);
when(searchContext.bigArrays()).thenReturn(new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), circuitBreakerService));
// TODO: now just needed for top_hits, this will need to be revised for other agg unit tests:
MapperService mapperService = mapperServiceMock();
Expand Down Expand Up @@ -146,28 +157,38 @@ protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregati
new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
}

protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
protected <A extends Aggregator> A createAggregator(Query query,
AggregationBuilder aggregationBuilder,
IndexSearcher indexSearcher,
IndexSettings indexSettings,
MappedFieldType... fieldTypes) throws IOException {
return createAggregator(aggregationBuilder, indexSearcher, indexSettings,
return createAggregator(query, aggregationBuilder, indexSearcher, indexSettings,
new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
}

protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
protected <A extends Aggregator> A createAggregator(Query query, AggregationBuilder aggregationBuilder,
IndexSearcher indexSearcher,
MultiBucketConsumer bucketConsumer,
MappedFieldType... fieldTypes) throws IOException {
return createAggregator(aggregationBuilder, indexSearcher, createIndexSettings(), bucketConsumer, fieldTypes);
return createAggregator(query, aggregationBuilder, indexSearcher, createIndexSettings(), bucketConsumer, fieldTypes);
}

protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
IndexSearcher indexSearcher,
IndexSettings indexSettings,
MultiBucketConsumer bucketConsumer,
MappedFieldType... fieldTypes) throws IOException {
return createAggregator(null, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes);
}

protected <A extends Aggregator> A createAggregator(Query query,
AggregationBuilder aggregationBuilder,
IndexSearcher indexSearcher,
IndexSettings indexSettings,
MultiBucketConsumer bucketConsumer,
MappedFieldType... fieldTypes) throws IOException {
@SuppressWarnings("unchecked")
A aggregator = (A) createAggregatorFactory(aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes)
A aggregator = (A) createAggregatorFactory(query, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes)
.create(null, true);
return aggregator;
}
Expand Down Expand Up @@ -262,7 +283,7 @@ protected <A extends InternalAggregation, C extends Aggregator> A search(IndexSe
int maxBucket,
MappedFieldType... fieldTypes) throws IOException {
MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(maxBucket);
C a = createAggregator(builder, searcher, bucketConsumer, fieldTypes);
C a = createAggregator(query, builder, searcher, bucketConsumer, fieldTypes);
a.preCollection();
searcher.search(query, a);
a.postCollection();
Expand Down Expand Up @@ -310,11 +331,11 @@ protected <A extends InternalAggregation, C extends Aggregator> A searchAndReduc
Query rewritten = searcher.rewrite(query);
Weight weight = searcher.createWeight(rewritten, true, 1f);
MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(maxBucket);
C root = createAggregator(builder, searcher, bucketConsumer, fieldTypes);
C root = createAggregator(query, builder, searcher, bucketConsumer, fieldTypes);

for (ShardSearcher subSearcher : subSearchers) {
MultiBucketConsumer shardBucketConsumer = new MultiBucketConsumer(maxBucket);
C a = createAggregator(builder, subSearcher, shardBucketConsumer, fieldTypes);
C a = createAggregator(query, builder, subSearcher, shardBucketConsumer, fieldTypes);
a.preCollection();
subSearcher.search(weight, a);
a.postCollection();
Expand Down

0 comments on commit bd11e6c

Please sign in to comment.