Skip to content

Commit

Permalink
Clean up some aggregation tests (backport of #66044) (#66134)
Browse files Browse the repository at this point in the history
This rewrites two tests for aggregations to use `AggregatorTestCase`'s
simpler way of making `Aggregator`s, allowing us to remove a ctor on
`ProductionAggregationContext` that we weren't happy about. Now there is
only a single test call to `ProductionAggregationContext` and we can
remove that soon.
  • Loading branch information
nik9000 authored Dec 9, 2020
1 parent fb62b35 commit e4096ae
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 153 deletions.
15 changes: 14 additions & 1 deletion server/src/main/java/org/elasticsearch/search/SearchService.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.TopDocs;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
Expand Down Expand Up @@ -104,6 +105,7 @@
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.internal.ShardSearchContextId;
import org.elasticsearch.search.internal.ShardSearchRequest;
import org.elasticsearch.search.internal.SubSearchContext;
import org.elasticsearch.search.lookup.SearchLookup;
import org.elasticsearch.search.profile.Profilers;
import org.elasticsearch.search.query.QueryPhase;
Expand Down Expand Up @@ -933,7 +935,18 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
}
context.terminateAfter(source.terminateAfter());
if (source.aggregations() != null && includeAggregations) {
AggregationContext aggContext = new ProductionAggregationContext(context, multiBucketConsumerService.create());
AggregationContext aggContext = new ProductionAggregationContext(
context.getQueryShardContext(),
context.query() == null ? new MatchAllDocsQuery() : context.query(),
context.getProfilers() == null ? null : context.getProfilers().getAggregationProfiler(),
multiBucketConsumerService.create(),
() -> new SubSearchContext(context).parsedQuery(context.parsedQuery()).fetchFieldsContext(context.fetchFieldsContext()),
context::addReleasable,
context.bitsetFilterCache(),
context.indexShard().shardId().hashCode(),
context::getRelativeTimeInMillis,
context::isCancelled
);
try {
AggregatorFactories factories = source.aggregations().build(aggContext, null);
context.aggregations(new SearchContextAggregations(factories));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
Expand All @@ -41,7 +40,6 @@
import org.elasticsearch.script.ScriptContext;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.MultiBucketConsumerService.MultiBucketConsumer;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.internal.SubSearchContext;
import org.elasticsearch.search.lookup.SearchLookup;
import org.elasticsearch.search.profile.aggregation.AggregationProfiler;
Expand Down Expand Up @@ -259,21 +257,6 @@ public static class ProductionAggregationContext extends AggregationContext {
private final LongSupplier relativeTimeInMillis;
private final Supplier<Boolean> isCancelled;

public ProductionAggregationContext(SearchContext context, MultiBucketConsumer multiBucketConsumer) {
this( // TODO we'd prefer to not use SearchContext everywhere but we have a bunch of tests that use this now
context.getQueryShardContext(),
context.query() == null ? new MatchAllDocsQuery() : context.query(),
context.getProfilers() == null ? null : context.getProfilers().getAggregationProfiler(),
multiBucketConsumer,
() -> new SubSearchContext(context).parsedQuery(context.parsedQuery()).fetchFieldsContext(context.fetchFieldsContext()),
context::addReleasable,
context.bitsetFilterCache(),
context.indexShard().shardId().hashCode(),
context::getRelativeTimeInMillis,
context::isCancelled
);
}

public ProductionAggregationContext(
QueryShardContext context,
Query topLevelQuery,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,54 +19,109 @@

package org.elasticsearch.search.aggregations;

import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.index.IndexService;
import org.elasticsearch.search.aggregations.MultiBucketConsumerService.MultiBucketConsumer;
import org.elasticsearch.search.aggregations.support.AggregationContext.ProductionAggregationContext;
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.store.Directory;
import org.elasticsearch.index.mapper.KeywordFieldMapper.KeywordFieldType;
import org.elasticsearch.script.AggregationScript;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.TopHitsAggregationBuilder;

import java.io.IOException;

import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class AggregationCollectorTests extends ESSingleNodeTestCase {
public class AggregationCollectorTests extends AggregatorTestCase {
public void testTerms() throws IOException {
    // A plain terms aggregation over a field never requires scores.
    boolean scoresNeeded = needsScores(termsBuilder().field("f"));
    assertFalse(scoresNeeded);
}

public void testSubTerms() throws IOException {
    // Nesting a second terms aggregation does not introduce a score dependency.
    TermsAggregationBuilder inner = new TermsAggregationBuilder("i").field("f");
    assertFalse(needsScores(termsBuilder().field("f").subAggregation(inner)));
}

public void testScoreConsumingScript() throws IOException {
    // NOTE(review): despite the method name, this compiles the "no_scores" mock
    // script, whose leaf factory reports needs_score() == false — confirm the
    // name wasn't meant to be swapped with testNonScoreConsumingScript.
    Script script = new Script("no_scores");
    assertFalse(needsScores(termsBuilder().script(script)));
}

public void testNonScoreConsumingScript() throws IOException {
    // NOTE(review): despite the method name, "with_scores" compiles to a mock
    // script whose needs_score() is true, so the aggregator must request scores.
    Script script = new Script("with_scores");
    assertTrue(needsScores(termsBuilder().script(script)));
}

public void testNeedsScores() throws Exception {
IndexService index = createIndex("idx");
client().prepareIndex("idx", "type", "1").setSource("f", 5).execute().get();
client().admin().indices().prepareRefresh("idx").get();
public void testSubScoreConsumingScript() throws IOException {
    // A score-free script on a sub-aggregation keeps the whole tree score-free.
    TermsAggregationBuilder sub = termsBuilder().script(new Script("no_scores"));
    assertFalse(needsScores(termsBuilder().field("f").subAggregation(sub)));
}

// simple field aggregation, no scores needed
String fieldAgg = "{ \"my_terms\": {\"terms\": {\"field\": \"f\"}}}";
assertFalse(needsScores(index, fieldAgg));
public void testSubNonScoreConsumingScript() throws IOException {
    // A score-consuming script on a sub-aggregation propagates the score
    // requirement up to the parent aggregator.
    TermsAggregationBuilder sub = termsBuilder().script(new Script("with_scores"));
    assertTrue(needsScores(termsBuilder().field("f").subAggregation(sub)));
}

// agg on a script => scores are needed
// TODO: can we use a mock script service here?
// String scriptAgg = "{ \"my_terms\": {\"terms\": {\"script\": \"doc['f'].value\"}}}";
// assertTrue(needsScores(index, scriptAgg));
//
// String subScriptAgg = "{ \"my_outer_terms\": { \"terms\": { \"field\": \"f\" }, \"aggs\": " + scriptAgg + "}}";
// assertTrue(needsScores(index, subScriptAgg));
public void testTopHits() throws IOException {
    // top_hits reports that it needs scores even as a top-level aggregation.
    TopHitsAggregationBuilder topHits = new TopHitsAggregationBuilder("h");
    assertTrue(needsScores(topHits));
}

// make sure the information is propagated to sub aggregations
String subFieldAgg = "{ \"my_outer_terms\": { \"terms\": { \"field\": \"f\" }, \"aggs\": " + fieldAgg + "}}";
assertFalse(needsScores(index, subFieldAgg));
public void testSubTopHits() throws IOException {
    // The top_hits score requirement is visible through the parent terms agg.
    TopHitsAggregationBuilder topHits = new TopHitsAggregationBuilder("h");
    assertTrue(needsScores(termsBuilder().field("f").subAggregation(topHits)));
}

// top_hits is a particular example of an aggregation that needs scores
String topHitsAgg = "{ \"my_hits\": {\"top_hits\": {}}}";
assertTrue(needsScores(index, topHitsAgg));
/**
 * Shorthand for a fresh, unconfigured {@code terms} aggregation builder named {@code "t"}.
 */
private TermsAggregationBuilder termsBuilder() {
    return new TermsAggregationBuilder("t");
}

private boolean needsScores(IndexService index, String agg) throws IOException {
try (XContentParser aggParser = createParser(JsonXContent.jsonXContent, agg)) {
aggParser.nextToken();
final AggregatorFactories factories = AggregatorFactories.parseAggregators(aggParser)
.build(new ProductionAggregationContext(createSearchContext(index), mock(MultiBucketConsumer.class)), null);
final Aggregator[] aggregators = factories.createTopLevelAggregators();
assertEquals(1, aggregators.length);
return aggregators[0].scoreMode().needsScores();
/**
 * Builds the {@code Aggregator} for {@code builder} against an empty in-memory
 * index (with {@code f} mapped as a keyword field) and reports whether its
 * {@code scoreMode} requires scores.
 */
private boolean needsScores(AggregationBuilder builder) throws IOException {
    try (
        Directory dir = newDirectory();
        RandomIndexWriter writer = new RandomIndexWriter(random(), dir);
        DirectoryReader reader = writer.getReader()
    ) {
        IndexSearcher searcher = new IndexSearcher(reader);
        return createAggregator(builder, searcher, new KeywordFieldType("f")).scoreMode().needsScores();
    }
}

@Override
protected ScriptService getMockScriptService() {
    // Stub compiler: "no_scores" yields a script that does not need scores,
    // "with_scores" yields one that does; anything else is a test bug.
    ScriptService scriptService = mock(ScriptService.class);
    when(scriptService.compile(any(), any())).then(invocation -> {
        Script script = (Script) invocation.getArguments()[0];
        String id = script.getIdOrCode();
        if (id.equals("no_scores")) {
            return scriptFactory(false);
        }
        if (id.equals("with_scores")) {
            return scriptFactory(true);
        }
        throw new UnsupportedOperationException();
    });
    return scriptService;
}

/**
 * Builds an {@code AggregationScript.Factory} whose leaf factory reports the
 * given score requirement; {@code newInstance} returns {@code null} because
 * these tests only inspect {@code needs_score()}, never execute the script.
 */
private static AggregationScript.Factory scriptFactory(boolean needsScores) {
    return (params, lookup) -> new AggregationScript.LeafFactory() {
        @Override
        public AggregationScript newInstance(LeafReaderContext ctx) throws IOException {
            return null;
        }

        @Override
        public boolean needs_score() {
            return needsScores;
        }
    };
}
}

This file was deleted.

Loading

0 comments on commit e4096ae

Please sign in to comment.