Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQL: Fix ORDER BY on aggregates and GROUPed BY fields #51894

Merged
merged 10 commits into from
Feb 12, 2020
27 changes: 24 additions & 3 deletions x-pack/plugin/sql/qa/src/main/resources/agg-ordering.sql-spec
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ aggNotSpecifiedInTheAggregateAndGroupWithHavingWithLimitAndDirection
SELECT gender, MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY MAX(salary) ASC, c DESC LIMIT 5;

groupAndAggNotSpecifiedInTheAggregateWithHaving
SELECT gender, MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY gender, MAX(salary);
SELECT gender, MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY gender NULLS FIRST, MAX(salary);

multipleAggsThatGetRewrittenWithAliasOnAMediumGroupBy
SELECT languages, MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY languages ORDER BY max;
Expand Down Expand Up @@ -136,5 +136,26 @@ SELECT gender AS g, first_name AS f, last_name AS l FROM test_emp GROUP BY f, ge
multipleGroupingsAndOrderingByGroups_8
SELECT gender AS g, first_name, last_name FROM test_emp GROUP BY g, last_name, first_name ORDER BY gender ASC, first_name DESC, last_name ASC;

multipleGroupingsAndOrderingByGroupsWithFunctions
SELECT first_name f, last_name l, gender g, CONCAT(first_name, last_name) c FROM test_emp GROUP BY gender, l, f, c ORDER BY gender, c DESC, first_name, last_name ASC;
multipleGroupingsAndOrderingByGroupsAndAggs_1
SELECT gender, MIN(salary) AS min, COUNT(*) AS c, MAX(salary) AS max FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY gender ASC NULLS FIRST, MAX(salary) DESC;

multipleGroupingsAndOrderingByGroupsAndAggs_2
SELECT gender, MIN(salary) AS min, COUNT(*) AS c, MAX(salary) AS max FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY gender DESC NULLS LAST, MAX(salary) ASC;

multipleGroupingsAndOrderingByGroupsWithFunctions_1
SELECT first_name f, last_name l, gender g, CONCAT(first_name, last_name) c FROM test_emp GROUP BY gender, l, f, c ORDER BY gender NULLS FIRST, c DESC, first_name, last_name ASC;

multipleGroupingsAndOrderingByGroupsWithFunctions_2
SELECT first_name f, last_name l, gender g, CONCAT(first_name, last_name) c FROM test_emp GROUP BY gender, l, f, c ORDER BY c DESC, gender DESC NULLS LAST, first_name, last_name ASC;

multipleGroupingsAndOrderingByGroupsAndAggregatesWithFunctions_1
SELECT CONCAT('foo', gender) g, MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY g ORDER BY 1 NULLS FIRST, 2, 3;

multipleGroupingsAndOrderingByGroupsAndAggregatesWithFunctions_2
SELECT CONCAT('foo', gender) g, MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY g ORDER BY 1 DESC NULLS LAST, 2, 3;

multipleGroupingsAndOrderingByGroupsAndAggregatesWithFunctions_3
SELECT CONCAT('foo', gender) g, MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY g ORDER BY 2, 1 NULLS FIRST, 3;

multipleGroupingsAndOrderingByGroupsAndAggregatesWithFunctions_4
SELECT CONCAT('foo', gender) g, MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY g ORDER BY 3 DESC, 1 NULLS FIRST, 2;
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
import org.elasticsearch.xpack.sql.querydsl.container.GlobalCountRef;
import org.elasticsearch.xpack.sql.querydsl.container.GroupByRef;
import org.elasticsearch.xpack.sql.querydsl.container.GroupByRef.Property;
import org.elasticsearch.xpack.sql.querydsl.container.GroupingFunctionSort;
import org.elasticsearch.xpack.sql.querydsl.container.MetricAggRef;
import org.elasticsearch.xpack.sql.querydsl.container.PivotColumnRef;
import org.elasticsearch.xpack.sql.querydsl.container.QueryContainer;
Expand Down Expand Up @@ -682,37 +683,34 @@ protected PhysicalPlan rule(OrderExec plan) {

// TODO: might need to validate whether the target field or group actually exist
if (group != null && group != Aggs.IMPLICIT_GROUP_KEY) {
// check whether the lookup matches a group
if (group.id().equals(lookup)) {
qContainer = qContainer.updateGroup(group.with(direction));
}
// else it's a leafAgg
else {
qContainer = qContainer.updateGroup(group.with(direction));
}
qContainer = qContainer.updateGroup(group.with(direction));
}

// field
if (orderExpression instanceof FieldAttribute) {
qContainer = qContainer.addSort(new AttributeSort((FieldAttribute) orderExpression, direction, missing));
}
// scalar functions typically require script ordering
else if (orderExpression instanceof ScalarFunction) {
ScalarFunction sf = (ScalarFunction) orderExpression;
// nope, use scripted sorting
qContainer = qContainer.addSort(new ScriptSort(Expressions.id(sf), sf.asScript(), direction, missing));
}
// histogram
else if (orderExpression instanceof Histogram) {
qContainer = qContainer.addSort(new GroupingFunctionSort(Expressions.id(orderExpression), direction, missing));
}
// score
else if (orderExpression instanceof Score) {
qContainer = qContainer.addSort(new ScoreSort(Expressions.id(orderExpression), direction, missing));
}
// agg function
else if (orderExpression instanceof AggregateFunction) {
qContainer = qContainer.addSort(new AggregateSort((AggregateFunction) orderExpression, direction, missing));
}
// unknown
else {
// scalar functions typically require script ordering
if (orderExpression instanceof ScalarFunction) {
ScalarFunction sf = (ScalarFunction) orderExpression;
// nope, use scripted sorting
qContainer = qContainer.addSort(new ScriptSort(sf.asScript(), direction, missing));
}
// score
else if (orderExpression instanceof Score) {
qContainer = qContainer.addSort(new ScoreSort(direction, missing));
}
// field
else if (orderExpression instanceof FieldAttribute) {
qContainer = qContainer.addSort(new AttributeSort((FieldAttribute) orderExpression, direction, missing));
}
// agg function
else if (orderExpression instanceof AggregateFunction) {
qContainer = qContainer.addSort(new AggregateSort((AggregateFunction) orderExpression, direction, missing));
} else {
// unknown
throw new SqlIllegalArgumentException("unsupported sorting expression {}", orderExpression);
}
throw new SqlIllegalArgumentException("unsupported sorting expression {}", orderExpression);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

package org.elasticsearch.xpack.sql.querydsl.container;

import org.elasticsearch.xpack.ql.expression.Expressions;
import org.elasticsearch.xpack.ql.expression.function.aggregate.AggregateFunction;

import java.util.Objects;
Expand All @@ -23,6 +24,11 @@ public AggregateFunction agg() {
return agg;
}

@Override
public String id() {
return Expressions.id(agg);
}

@Override
public int hashCode() {
return Objects.hash(agg, direction(), missing());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.elasticsearch.xpack.sql.querydsl.container;

import org.elasticsearch.xpack.ql.expression.Attribute;
import org.elasticsearch.xpack.ql.expression.Expressions;

import java.util.Objects;

Expand All @@ -22,6 +23,11 @@ public Attribute attribute() {
return attribute;
}

@Override
public String id() {
return Expressions.id(attribute);
}

@Override
public int hashCode() {
return Objects.hash(attribute, direction(), missing());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.querydsl.container;

import java.util.Objects;

public class GroupingFunctionSort extends Sort {

private final String id;

public GroupingFunctionSort(String id, Direction direction, Missing missing) {
super(direction, missing);
this.id = id;
}

@Override
public String id() {
return id;
}

@Override
public int hashCode() {
return Objects.hash(direction(), missing(), id);
}

@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}

if (obj == null || getClass() != obj.getClass()) {
return false;
}

GroupingFunctionSort other = (GroupingFunctionSort) obj;
return Objects.equals(direction(), other.direction())
&& Objects.equals(missing(), other.missing())
&& Objects.equals(id(), other.id());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.Expressions;
import org.elasticsearch.xpack.ql.expression.FieldAttribute;
import org.elasticsearch.xpack.ql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.ql.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.ql.expression.gen.pipeline.ConstantInput;
import org.elasticsearch.xpack.ql.expression.gen.pipeline.Pipe;
Expand Down Expand Up @@ -134,45 +133,40 @@ public List<Tuple<Integer, Comparator>> sortingColumns() {
return emptyList();
}

List<Tuple<Integer, Comparator>> sortingColumns = new ArrayList<>(sort.size());

boolean aggSort = false;
for (Sort s : sort) {
Tuple<Integer, Comparator> tuple = new Tuple<>(Integer.valueOf(-1), null);

if (s instanceof AggregateSort) {
AggregateSort as = (AggregateSort) s;
// find the relevant column of each aggregate function
AggregateFunction af = as.agg();

aggSort = true;
int atIndex = -1;
String id = Expressions.id(af);

for (int i = 0; i < fields.size(); i++) {
Tuple<FieldExtraction, String> field = fields.get(i);
if (field.v2().equals(id)) {
atIndex = i;
break;
}
}
if (atIndex == -1) {
throw new SqlIllegalArgumentException("Cannot find backing column for ordering aggregation [{}]", s);
}
// assemble a comparator for it
Comparator comp = s.direction() == Sort.Direction.ASC ? Comparator.naturalOrder() : Comparator.reverseOrder();
comp = s.missing() == Sort.Missing.FIRST ? Comparator.nullsFirst(comp) : Comparator.nullsLast(comp);

tuple = new Tuple<>(Integer.valueOf(atIndex), comp);
customSort = Boolean.TRUE;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not breaking early, if there's the first AggregateSort found?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot break after the first AggregateSort. Maybe we could break after the last AggregateSort.
If we have:

SELECT f1, f2, f3, MAX(f4) as max, MIN(f5) as min
FROM test
GROUP BY f1, f2, f3
ORDER BY f1, max, f2, min, f3

we cannot break after max, we could break after min.

I'd rather leave the fix as is and introduce this optimisation in a separate PR where it's properly tested that it works.
(Needs some carefully chosen data set to test this ordering case)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand. That's a simple loop that, when finds an AggregateSort, will set customSort to TRUE. It doesn't really matter what's after in the list of sorts because it doesn't change the value of customSort.
Also, I meant breaking from inside the loop not from the method...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, misunderstood you there, sure we should break once the 1st is found.

break;
}
sortingColumns.add(tuple);
}


// If no custom sort is used break early
if (customSort == null) {
customSort = Boolean.valueOf(aggSort);
customSort = Boolean.FALSE;
return emptyList();
}

return aggSort ? sortingColumns : emptyList();
List<Tuple<Integer, Comparator>> sortingColumns = new ArrayList<>(sort.size());
for (Sort s: sort) {
int atIndex = -1;
for (int i = 0; i < fields.size(); i++) {
Tuple<FieldExtraction, String> field = fields.get(i);
if (field.v2().equals(s.id())) {
atIndex = i;
break;
}
}
if (atIndex == -1) {
throw new SqlIllegalArgumentException("Cannot find backing column for ordering aggregation [{}]", s);
}
// assemble a comparator for it
Comparator comp = s.direction() == Sort.Direction.ASC ? Comparator.naturalOrder() : Comparator.reverseOrder();
comp = s.missing() == Sort.Missing.FIRST ? Comparator.nullsFirst(comp) : Comparator.nullsLast(comp);

sortingColumns.add(new Tuple<>(Integer.valueOf(atIndex), comp));
}

return sortingColumns;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,22 @@
import java.util.Objects;

public class ScoreSort extends Sort {
public ScoreSort(Direction direction, Missing missing) {

private final String id;

public ScoreSort(String id, Direction direction, Missing missing) {
super(direction, missing);
this.id = id;
}

@Override
public String id() {
return id;
}

@Override
public int hashCode() {
return Objects.hash(direction(), missing());
return Objects.hash(direction(), missing(), id());
}

@Override
Expand All @@ -29,6 +38,7 @@ public boolean equals(Object obj) {

ScriptSort other = (ScriptSort) obj;
return Objects.equals(direction(), other.direction())
&& Objects.equals(missing(), other.missing());
&& Objects.equals(missing(), other.missing())
&& Objects.equals(id(), other.id());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,27 @@

public class ScriptSort extends Sort {

private final String id;
private final ScriptTemplate script;

public ScriptSort(ScriptTemplate script, Direction direction, Missing missing) {
public ScriptSort(String id, ScriptTemplate script, Direction direction, Missing missing) {
super(direction, missing);
this.id = id;
this.script = Scripts.nullSafeSort(script);
}

@Override
public String id() {
return id;
}

public ScriptTemplate script() {
return script;
}

@Override
public int hashCode() {
return Objects.hash(direction(), missing(), script);
return Objects.hash(direction(), missing(), id(), script());
}

@Override
Expand All @@ -37,10 +44,11 @@ public boolean equals(Object obj) {
if (obj == null || getClass() != obj.getClass()) {
return false;
}

ScriptSort other = (ScriptSort) obj;
return Objects.equals(direction(), other.direction())
&& Objects.equals(missing(), other.missing())
&& Objects.equals(script, other.script);
&& Objects.equals(id(), other.id())
&& Objects.equals(script(), other.script());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ protected Sort(Direction direction, Missing nulls) {
this.missing = nulls;
}

public abstract String id();

public Direction direction() {
return direction;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public void testSelectScoreForcesTrackingScore() {

public void testSortScoreSpecified() {
QueryContainer container = new QueryContainer()
.addSort(new ScoreSort(Direction.DESC, null));
.addSort(new ScoreSort("id", Direction.DESC, null));
SearchSourceBuilder sourceBuilder = SourceGenerator.sourceBuilder(container, null, randomIntBetween(1, 10));
assertEquals(singletonList(scoreSort()), sourceBuilder.sorts());
}
Expand Down Expand Up @@ -137,4 +137,4 @@ public void testNoSortIfAgg() {
SearchSourceBuilder sourceBuilder = SourceGenerator.sourceBuilder(container, null, randomIntBetween(1, 10));
assertNull(sourceBuilder.sorts());
}
}
}