Skip to content

Commit

Permalink
SQL: Enable accurate hit tracking on demand (#39527)
Browse files Browse the repository at this point in the history
Queries that require counting all hits (COUNT(*) with an implicit
GROUP BY) now enable accurate hit tracking.

Fix #37971

(cherry picked from commit 265b637)
  • Loading branch information
costin committed Mar 1, 2019
1 parent 1a8cb52 commit 5e88001
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ public void testExplainWithCount() throws IOException {
assertThat(readLine(), startsWith(" }"));
assertThat(readLine(), startsWith(" }"));
assertThat(readLine(), startsWith(" ]"));
assertThat(readLine(), startsWith(" \"track_total_hits\" : 2147483647"));
assertThat(readLine(), startsWith("}]"));
assertEquals("", readLine());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ private static void optimize(QueryContainer query, SearchSourceBuilder builder)
// disable source fetching (only doc values are used)
disableSource(builder);
}
if (query.shouldTrackHits()) {
builder.trackTotalHits(true);
}
}

private static void disableSource(SearchSourceBuilder builder) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.Foldables;
import org.elasticsearch.xpack.sql.expression.Literal;
import org.elasticsearch.xpack.sql.expression.NamedExpression;
import org.elasticsearch.xpack.sql.expression.Order;
import org.elasticsearch.xpack.sql.expression.function.Function;
Expand Down Expand Up @@ -152,7 +151,8 @@ protected PhysicalPlan rule(ProjectExec project) {
queryC.pseudoFunctions(),
new AttributeMap<>(processors),
queryC.sort(),
queryC.limit());
queryC.limit(),
queryC.shouldTrackHits());
return new EsQueryExec(exec.source(), exec.index(), project.output(), clone);
}
return project;
Expand Down Expand Up @@ -180,7 +180,8 @@ protected PhysicalPlan rule(FilterExec plan) {
qContainer.pseudoFunctions(),
qContainer.scalarFunctions(),
qContainer.sort(),
qContainer.limit());
qContainer.limit(),
qContainer.shouldTrackHits());

return exec.with(qContainer);
}
Expand Down Expand Up @@ -391,10 +392,16 @@ private Tuple<QueryContainer, AggPathInput> addAggFunction(GroupByKey groupingAg
if (f instanceof Count) {
Count c = (Count) f;
// COUNT(*) or COUNT(<literal>)
if (c.field() instanceof Literal) {
AggRef ref = groupingAgg == null ?
GlobalCountRef.INSTANCE :
new GroupByRef(groupingAgg.id(), Property.COUNT, null);
if (c.field().foldable()) {
AggRef ref = null;

if (groupingAgg == null) {
ref = GlobalCountRef.INSTANCE;
// if the count points to the total track hits, enable accurate count retrieval
queryC = queryC.withTrackHits();
} else {
ref = new GroupByRef(groupingAgg.id(), Property.COUNT, null);
}

Map<String, GroupByKey> pseudoFunctions = new LinkedHashMap<>(queryC.pseudoFunctions());
pseudoFunctions.put(functionId, groupingAgg);
Expand All @@ -406,7 +413,7 @@ private Tuple<QueryContainer, AggPathInput> addAggFunction(GroupByKey groupingAg
queryC = queryC.with(queryC.aggs().addAgg(leafAgg));
return new Tuple<>(queryC, a);
}
// the only variant left - COUNT(DISTINCT) - will be covered by the else branch below
// the only variant left - COUNT(DISTINCT) - will be covered by the else branch below as it maps to an aggregation
}

AggPathInput aggInput = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.elasticsearch.xpack.sql.querydsl.agg.Aggs;
import org.elasticsearch.xpack.sql.querydsl.agg.GroupByKey;
import org.elasticsearch.xpack.sql.querydsl.agg.LeafAgg;
import org.elasticsearch.xpack.sql.querydsl.container.GroupByRef.Property;
import org.elasticsearch.xpack.sql.querydsl.query.BoolQuery;
import org.elasticsearch.xpack.sql.querydsl.query.MatchAll;
import org.elasticsearch.xpack.sql.querydsl.query.NestedQuery;
Expand Down Expand Up @@ -81,23 +80,26 @@ public class QueryContainer {

private final Set<Sort> sort;
private final int limit;
private final boolean trackHits;

// computed
private Boolean aggsOnly;
private Boolean customSort;

/**
 * Creates an empty container: every component is passed as {@code null}, the limit is {@code -1}
 * and (in the 9-argument call) hit tracking is off.
 */
public QueryContainer() {
// NOTE(review): two consecutive this(...) calls - the 8-argument one looks like the pre-change side of the diff; confirm.
this(null, null, null, null, null, null, null, -1);
this(null, null, null, null, null, null, null, -1, false);
}

public QueryContainer(Query query,
Aggs aggs,
List<Tuple<FieldExtraction, ExpressionId>> fields,
public QueryContainer(Query query,
Aggs aggs,
List<Tuple<FieldExtraction,
ExpressionId>> fields,
AttributeMap<Attribute> aliases,
Map<String, GroupByKey> pseudoFunctions,
AttributeMap<Pipe> scalarFunctions,
Set<Sort> sort,
int limit) {
Map<String, GroupByKey> pseudoFunctions,
AttributeMap<Pipe> scalarFunctions,
Set<Sort> sort,
int limit,
boolean trackHits) {
this.query = query;
this.aggs = aggs == null ? Aggs.EMPTY : aggs;
this.fields = fields == null || fields.isEmpty() ? emptyList() : fields;
Expand All @@ -106,6 +108,7 @@ public QueryContainer(Query query,
this.scalarFunctions = scalarFunctions == null || scalarFunctions.isEmpty() ? AttributeMap.emptyAttributeMap() : scalarFunctions;
this.sort = sort == null || sort.isEmpty() ? emptySet() : sort;
this.limit = limit;
this.trackHits = trackHits;
}

/**
Expand Down Expand Up @@ -230,38 +233,46 @@ public boolean hasColumns() {
return fields.size() > 0;
}

/**
 * Reports whether this container has been flagged to track hits.
 *
 * @return the value of the {@code trackHits} flag
 */
public boolean shouldTrackHits() {
    return this.trackHits;
}

//
// copy methods
//

/**
 * Returns a new container that uses {@code q} as its query and carries over every other component.
 */
public QueryContainer with(Query q) {
// NOTE(review): the 8-argument call below looks like the pre-change side of the diff; confirm only the trackHits variant remains.
return new QueryContainer(q, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit);
return new QueryContainer(q, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit, trackHits);
}

/**
 * Returns a new container that uses {@code a} as its aliases and carries over every other component.
 */
public QueryContainer withAliases(AttributeMap<Attribute> a) {
// NOTE(review): the 8-argument call below looks like the pre-change side of the diff; confirm only the trackHits variant remains.
return new QueryContainer(query, aggs, fields, a, pseudoFunctions, scalarFunctions, sort, limit);
return new QueryContainer(query, aggs, fields, a, pseudoFunctions, scalarFunctions, sort, limit, trackHits);
}

/**
 * Returns a new container that uses {@code p} as its pseudo-functions map and carries over every other component.
 */
public QueryContainer withPseudoFunctions(Map<String, GroupByKey> p) {
// NOTE(review): the 8-argument call below looks like the pre-change side of the diff; confirm only the trackHits variant remains.
return new QueryContainer(query, aggs, fields, aliases, p, scalarFunctions, sort, limit);
return new QueryContainer(query, aggs, fields, aliases, p, scalarFunctions, sort, limit, trackHits);
}

/**
 * Returns a new container that uses {@code a} as its aggregations and carries over every other component.
 */
public QueryContainer with(Aggs a) {
// NOTE(review): the 8-argument call below looks like the pre-change side of the diff; confirm only the trackHits variant remains.
return new QueryContainer(query, a, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit);
return new QueryContainer(query, a, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit, trackHits);
}

/**
 * Returns a container with limit {@code l}: this same instance when {@code l} equals the current limit,
 * otherwise a new container that differs only in its limit.
 */
public QueryContainer withLimit(int l) {
// NOTE(review): the 8-argument variant below looks like the pre-change side of the diff; confirm only the trackHits variant remains.
return l == limit ? this : new QueryContainer(query, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, l);
return l == limit ? this : new QueryContainer(query, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, l, trackHits);
}

/**
 * Returns a container with hit tracking switched on.
 *
 * @return this same instance when tracking is already enabled, otherwise a new container
 *         built from the same components with {@code trackHits} set to {@code true}
 */
public QueryContainer withTrackHits() {
    if (trackHits) {
        // already enabled - no copy needed
        return this;
    }
    return new QueryContainer(query, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit, true);
}

/**
 * Returns a new container that uses {@code procs} as its scalar functions and carries over every other component.
 */
public QueryContainer withScalarProcessors(AttributeMap<Pipe> procs) {
// NOTE(review): the 8-argument call below looks like the pre-change side of the diff; confirm only the trackHits variant remains.
return new QueryContainer(query, aggs, fields, aliases, pseudoFunctions, procs, sort, limit);
return new QueryContainer(query, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit, trackHits);
}

/**
 * Returns a new container whose sort set is a copy of the current one with {@code sortable} added.
 * The copy is a {@link LinkedHashSet}, so existing entries keep their insertion order.
 */
public QueryContainer addSort(Sort sortable) {
// copy first so this container's own sort set is not modified
Set<Sort> sort = new LinkedHashSet<>(this.sort);
sort.add(sortable);
// NOTE(review): the 8-argument call below looks like the pre-change side of the diff; confirm only the trackHits variant remains.
return new QueryContainer(query, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit);
return new QueryContainer(query, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit, trackHits);
}

private String aliasName(Attribute attr) {
Expand All @@ -287,7 +298,7 @@ private Tuple<QueryContainer, FieldExtraction> nestedHitFieldRef(FieldAttribute
attr.field().isAggregatable(), attr.parent().name());
nestedRefs.add(nestedFieldRef);

return new Tuple<>(new QueryContainer(q, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit),
return new Tuple<>(new QueryContainer(q, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit, trackHits),
nestedFieldRef);
}

Expand Down Expand Up @@ -390,7 +401,7 @@ public QueryContainer addColumn(FieldExtraction ref, Attribute attr) {
ExpressionId id = attr instanceof AggregateFunctionAttribute ? ((AggregateFunctionAttribute) attr).innerId() : attr.id();
return new QueryContainer(query, aggs, combine(fields, new Tuple<>(ref, id)), aliases, pseudoFunctions,
scalarFunctions,
sort, limit);
sort, limit, trackHits);
}

public AttributeMap<Pipe> scalarFunctions() {
Expand All @@ -401,16 +412,6 @@ public AttributeMap<Pipe> scalarFunctions() {
// agg methods
//

/**
 * Returns a new container with a count column added for {@code functionId}.
 * When {@code group} is {@code null} the column refers to {@code GlobalCountRef.INSTANCE}; otherwise it is a
 * {@code GroupByRef} on the group's id with {@code Property.COUNT}. The function id (as a string) is also
 * mapped to {@code group} in a copy of the pseudo-functions map.
 */
// NOTE(review): this method still calls the 8-argument constructor and appears to be the removed side of the diff; confirm.
public QueryContainer addAggCount(GroupByKey group, ExpressionId functionId) {
FieldExtraction ref = group == null ? GlobalCountRef.INSTANCE : new GroupByRef(group.id(), Property.COUNT, null);
// copy so this container's own map is not modified
Map<String, GroupByKey> pseudoFunctions = new LinkedHashMap<>(this.pseudoFunctions);
pseudoFunctions.put(functionId.toString(), group);
return new QueryContainer(query, aggs, combine(fields, new Tuple<>(ref, functionId)),
aliases,
pseudoFunctions,
scalarFunctions, sort, limit);
}

/**
 * Returns a container whose aggregations additionally include {@code agg}.
 *
 * @param groupId not read by this method
 * @param agg     the leaf aggregation to add
 */
public QueryContainer addAgg(String groupId, LeafAgg agg) {
    Aggs extended = aggs.addAgg(agg);
    return with(extended);
}
Expand Down Expand Up @@ -465,4 +466,4 @@ public String toString() {
throw new RuntimeException("error rendering", e);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.search.aggregations.AggregatorFactories.Builder;
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.test.ESTestCase;
Expand Down Expand Up @@ -111,6 +112,13 @@ public void testNoSort() {
assertEquals(singletonList(fieldSort("_doc").order(SortOrder.ASC)), sourceBuilder.sorts());
}

/** A container with hit tracking enabled must yield a source builder that tracks total hits accurately. */
public void testTrackHits() {
    QueryContainer trackingContainer = new QueryContainer().withTrackHits();
    int size = randomIntBetween(1, 10);
    SearchSourceBuilder source = SourceGenerator.sourceBuilder(trackingContainer, null, size);

    Integer expected = Integer.valueOf(SearchContext.TRACK_TOTAL_HITS_ACCURATE);
    assertEquals("Should have tracked hits", expected, source.trackTotalHitsUpTo());
}

public void testNoSortIfAgg() {
QueryContainer container = new QueryContainer()
.addGroups(singletonList(new GroupByValue("group_id", "group_column")))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -678,4 +678,54 @@ public void testTopHitsAggregationWithTwoArgs() {
"{\"date\":{\"order\":\"desc\",\"missing\":\"_last\",\"unmapped_type\":\"date\"}}]}}}}}"));
}
}


/** COUNT(*) without a GROUP BY must plan to a query container that tracks hits. */
public void testGlobalCountInImplicitGroupByForcesTrackHits() throws Exception {
    PhysicalPlan plan = optimizeAndPlan("SELECT COUNT(*) FROM test");
    assertEquals(EsQueryExec.class, plan.getClass());

    boolean tracksHits = ((EsQueryExec) plan).queryContainer().shouldTrackHits();
    assertTrue("Should be tracking hits", tracksHits);
}

/** COUNT(ALL *) without a GROUP BY must plan to a query container that tracks hits. */
public void testGlobalCountAllInImplicitGroupByForcesTrackHits() throws Exception {
    PhysicalPlan plan = optimizeAndPlan("SELECT COUNT(ALL *) FROM test");
    assertEquals(EsQueryExec.class, plan.getClass());

    boolean tracksHits = ((EsQueryExec) plan).queryContainer().shouldTrackHits();
    assertTrue("Should be tracking hits", tracksHits);
}

/** COUNT(*) under an explicit GROUP BY must not switch hit tracking on. */
public void testGlobalCountInSpecificGroupByDoesNotForceTrackHits() throws Exception {
    PhysicalPlan plan = optimizeAndPlan("SELECT COUNT(*) FROM test GROUP BY int");
    assertEquals(EsQueryExec.class, plan.getClass());

    boolean tracksHits = ((EsQueryExec) plan).queryContainer().shouldTrackHits();
    assertFalse("Should NOT be tracking hits", tracksHits);
}

/** COUNT(ALL field) must not switch hit tracking on. */
public void testFieldAllCountDoesNotTrackHits() throws Exception {
    PhysicalPlan plan = optimizeAndPlan("SELECT COUNT(ALL int) FROM test");
    assertEquals(EsQueryExec.class, plan.getClass());

    boolean tracksHits = ((EsQueryExec) plan).queryContainer().shouldTrackHits();
    assertFalse("Should NOT be tracking hits", tracksHits);
}

/** COUNT(field) must not switch hit tracking on. */
public void testFieldCountDoesNotTrackHits() throws Exception {
    PhysicalPlan plan = optimizeAndPlan("SELECT COUNT(int) FROM test");
    assertEquals(EsQueryExec.class, plan.getClass());

    boolean tracksHits = ((EsQueryExec) plan).queryContainer().shouldTrackHits();
    assertFalse("Should NOT be tracking hits", tracksHits);
}

/** COUNT(DISTINCT field) must not switch hit tracking on. */
public void testDistinctCountDoesNotTrackHits() throws Exception {
    PhysicalPlan plan = optimizeAndPlan("SELECT COUNT(DISTINCT int) FROM test");
    assertEquals(EsQueryExec.class, plan.getClass());

    boolean tracksHits = ((EsQueryExec) plan).queryContainer().shouldTrackHits();
    assertFalse("Should NOT be tracking hits", tracksHits);
}

/** A plain field projection with no COUNT must not switch hit tracking on. */
public void testNoCountDoesNotTrackHits() throws Exception {
    PhysicalPlan plan = optimizeAndPlan("SELECT int FROM test");
    assertEquals(EsQueryExec.class, plan.getClass());

    boolean tracksHits = ((EsQueryExec) plan).queryContainer().shouldTrackHits();
    assertFalse("Should NOT be tracking hits", tracksHits);
}
}

0 comments on commit 5e88001

Please sign in to comment.