Skip to content

Commit

Permalink
Add support for GROUP BY ALL aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
ebyhr committed Nov 8, 2024
1 parent 1135d71 commit 0b7ccc1
Show file tree
Hide file tree
Showing 8 changed files with 245 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ querySpecification
;

groupBy
: setQuantifier? groupingElement (',' groupingElement)*
: ASTERISK | (setQuantifier? groupingElement (',' groupingElement)*)
;

groupingElement
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4446,6 +4446,32 @@ private GroupingSetAnalysis analyzeGroupBy(QuerySpecification node, Scope scope,
ImmutableList.Builder<Expression> groupingExpressions = ImmutableList.builder();

checkGroupingSetsCount(node.getGroupBy().get());
if (node.getGroupBy().get().isAsterisk()) {
// Analyze non-aggregation outputs for GROUP BY *
for (Expression column : outputExpressions) {
if (column instanceof FunctionCall functionCall) {
ResolvedFunction function = getResolvedFunction(functionCall);
if (function.functionKind() == AGGREGATE) {
continue;
}
}
else {
verifyNoAggregateWindowOrGroupingFunctions(session, functionResolver, accessControl, column, "GROUP BY clause");
analyzeExpression(column, scope);
}

ResolvedField field = analysis.getColumnReferenceFields().get(NodeRef.of(column));
if (field != null) {
sets.add(ImmutableList.of(ImmutableSet.of(field.getFieldId())));
}
else {
analysis.recordSubqueries(node, analyzeExpression(column, scope));
complexExpressions.add(column);
}

groupingExpressions.add(column);
}
}
for (GroupingElement groupingElement : node.getGroupBy().get().getGroupingElements()) {
if (groupingElement instanceof SimpleGroupBy) {
for (Expression column : groupingElement.getExpressions()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3163,6 +3163,8 @@ public void testGroupBy()
{
// TODO: validate output
analyze("SELECT a, SUM(b) FROM t1 GROUP BY a");
analyze("SELECT a, SUM(b) FROM t1 GROUP BY *");
analyze("SELECT a as x, SUM(b) FROM t1 GROUP BY *");
}

@Test
Expand Down
125 changes: 125 additions & 0 deletions core/trino-main/src/test/java/io/trino/sql/query/TestGroupBy.java
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,129 @@ public void testGroupByRepeatedOrdinals()
"SELECT null GROUP BY 1, 1"))
.matches("VALUES null");
}

@Test
void testGroupByAll()
{
assertThat(assertions.query(
"""
SELECT *
FROM (VALUES 1) t(a)
GROUP BY *
"""))
.matches("VALUES 1");

assertThat(assertions.query(
"""
SELECT *
FROM (VALUES 1, 2) t(a)
GROUP BY *
"""))
.matches("VALUES 1, 2");

assertThat(assertions.query(
"""
SELECT sum(a)
FROM (VALUES (1), (2)) t(a)
GROUP BY *
"""))
.matches("VALUES BIGINT '3'");

assertThat(assertions.query(
"""
SELECT a, sum(b)
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY *
"""))
.matches("VALUES (1, BIGINT '30')");

assertThat(assertions.query(
"""
SELECT a AS new_a, sum(b) AS sum_b
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY *
"""))
.matches("VALUES (1, BIGINT '30')");

assertThat(assertions.query(
"""
SELECT a + 1, sum(b)
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY *
"""))
.matches("VALUES (2, BIGINT '30')");

assertThat(assertions.query(
"""
SELECT abs(a), sum(b)
FROM (VALUES (-1, 10), (-1, 20)) t(a, b)
GROUP BY *
"""))
.matches("VALUES (1, BIGINT '30')");

assertThat(assertions.query(
"""
SELECT sum(b), a
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY *
"""))
.matches("VALUES (BIGINT '30', 1)");

assertThat(assertions.query(
"""
SELECT sum(a)
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY *
"""))
.matches("VALUES (BIGINT '2')");

assertThat(assertions.query(
"""
SELECT a, count(*)
FROM (VALUES 1, 2, 2) t(a)
GROUP BY ALL a
"""))
.matches("VALUES (1, BIGINT '1'), (2, BIGINT '2')");

assertThat(assertions.query(
"""
SELECT a, b, count(*)
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY ALL a, b
"""))
.matches("VALUES (1, 10, BIGINT '1'), (1, 20, BIGINT '1')");

// empty grouping set
assertThat(assertions.query(
"""
SELECT count(*)
FROM (VALUES 1, 2, 3) t(a)
GROUP BY ALL ()
"""))
.matches("VALUES BIGINT '3'");

// grouping element list doesn't specify all target column names
assertThat(assertions.query("""
SELECT a, b, count(*)
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY ALL a
"""))
.failure().hasMessage("line 1:11: 'b' must be an aggregate expression or appear in GROUP BY clause");

// GROUP BY without set quantifier should fail
assertThat(assertions.query("""
SELECT a, count(*)
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY
"""))
.failure().hasMessage("line 4:1: mismatched input '<EOF>'. Expecting: '(', '*', 'ALL', 'CUBE', 'DISTINCT', 'GROUPING', 'ROLLUP', <expression>");

// GROUP BY DISTINCT must have grouping element list
assertThat(assertions.query("""
SELECT a, count(*)
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY DISTINCT
"""))
.failure().hasMessage("line 4:1: mismatched input '<EOF>'. Expecting: '(', 'CUBE', 'GROUPING', 'ROLLUP', <expression>");
}
}
19 changes: 18 additions & 1 deletion core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
import io.trino.sql.tree.GrantObject;
import io.trino.sql.tree.GrantRoles;
import io.trino.sql.tree.GrantorSpecification;
import io.trino.sql.tree.GroupBy;
import io.trino.sql.tree.Identifier;
import io.trino.sql.tree.IfStatement;
import io.trino.sql.tree.Insert;
Expand Down Expand Up @@ -697,7 +698,7 @@ protected Void visitQuerySpecification(QuerySpecification node, Integer indent)
append(indent, "WHERE " + formatExpression(where)).append('\n'));

node.getGroupBy().ifPresent(groupBy ->
append(indent, "GROUP BY " + (groupBy.isDistinct() ? " DISTINCT " : "") + formatGroupBy(groupBy.getGroupingElements())).append('\n'));
append(indent, "GROUP BY " + formatGroupByType(groupBy.getType()) + formatGroupBy(groupBy.getGroupingElements())).append('\n'));

node.getHaving().ifPresent(having -> append(indent, "HAVING " + formatExpression(having))
.append('\n'));
Expand All @@ -715,6 +716,22 @@ protected Void visitQuerySpecification(QuerySpecification node, Integer indent)
return null;
}

private static String formatGroupByType(Optional<GroupBy.Type> type)
{
if (!type.isPresent()) {
return "";
}
switch (type.get()) {
case DISTINCT:
return " DISTINCT ";
case ALL:
return " ALL ";
case ASTERISK:
return " * ";
}
throw new UnsupportedOperationException("unknown group by type: " + type);
}

@Override
protected Void visitOrderBy(OrderBy node, Integer indent)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1274,7 +1274,7 @@ public Node visitQuerySpecification(SqlBaseParser.QuerySpecificationContext cont
@Override
public Node visitGroupBy(SqlBaseParser.GroupByContext context)
{
return new GroupBy(getLocation(context), isDistinct(context.setQuantifier()), visit(context.groupingElement(), GroupingElement.class));
return new GroupBy(getLocation(context), setQuantifier(context), visit(context.groupingElement(), GroupingElement.class));
}

@Override
Expand Down Expand Up @@ -4070,6 +4070,23 @@ private static boolean isDistinct(SqlBaseParser.SetQuantifierContext setQuantifi
return setQuantifier != null && setQuantifier.DISTINCT() != null;
}

private static Optional<GroupBy.Type> setQuantifier(SqlBaseParser.GroupByContext groupBy)
{
if (groupBy.ASTERISK() != null) {
return Optional.of(GroupBy.Type.ASTERISK);
}
if (groupBy.setQuantifier() == null) {
return Optional.empty();
}
if (groupBy.setQuantifier().DISTINCT() != null) {
return Optional.of(GroupBy.Type.DISTINCT);
}
if (groupBy.setQuantifier().ALL() != null) {
return Optional.of(GroupBy.Type.ALL);
}
throw new UnsupportedOperationException("Unexpected group by context: " + groupBy);
}

private static boolean isHexDigit(char c)
{
return ((c >= '0') && (c <= '9')) ||
Expand Down
48 changes: 36 additions & 12 deletions core/trino-parser/src/main/java/io/trino/sql/tree/GroupBy.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,37 +20,61 @@
import java.util.Optional;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.sql.tree.GroupBy.Type.ASTERISK;
import static io.trino.sql.tree.GroupBy.Type.DISTINCT;
import static java.util.Objects.requireNonNull;

public class GroupBy
extends Node
{
private final boolean isDistinct;
public enum Type
{
DISTINCT,
ALL,
ASTERISK,
/**/
}

private final Optional<Type> type;
private final List<GroupingElement> groupingElements;

@Deprecated
public GroupBy(boolean isDistinct, List<GroupingElement> groupingElements)
public GroupBy(Optional<Type> type, List<GroupingElement> groupingElements)
{
this(Optional.empty(), isDistinct, groupingElements);
this(Optional.empty(), type, groupingElements);
}

private GroupBy(Optional<NodeLocation> location, boolean isDistinct, List<GroupingElement> groupingElements)
private GroupBy(Optional<NodeLocation> location, Optional<Type> type, List<GroupingElement> groupingElements)
{
super(location);
this.isDistinct = isDistinct;
this.type = requireNonNull(type, "type is null");
this.groupingElements = ImmutableList.copyOf(requireNonNull(groupingElements));
if (type.isPresent() && type.get() != ASTERISK) {
checkArgument(!groupingElements.isEmpty(), "groupingElements must not be empty");
}
}

public Optional<Type> getType()
{
return type;
}

public GroupBy(NodeLocation location, boolean isDistinct, List<GroupingElement> groupingElements)
public GroupBy(NodeLocation location,Optional<Type> type, List<GroupingElement> groupingElements)
{
super(location);
this.isDistinct = isDistinct;
this.type = requireNonNull(type, "type is null");
this.groupingElements = ImmutableList.copyOf(groupingElements);
}

public boolean isDistinct()
{
return isDistinct;
return type.isPresent() && type.get() == DISTINCT;
}

public boolean isAsterisk()
{
return type.isPresent() && type.get() == ASTERISK;
}

public List<GroupingElement> getGroupingElements()
Expand Down Expand Up @@ -80,21 +104,21 @@ public boolean equals(Object o)
return false;
}
GroupBy groupBy = (GroupBy) o;
return isDistinct == groupBy.isDistinct &&
return Objects.equals(type, groupBy.type) &&
Objects.equals(groupingElements, groupBy.groupingElements);
}

@Override
public int hashCode()
{
return Objects.hash(isDistinct, groupingElements);
return Objects.hash(type, groupingElements);
}

@Override
public String toString()
{
return toStringHelper(this)
.add("isDistinct", isDistinct)
.add("type", type.orElse(null))
.add("groupingElements", groupingElements)
.toString();
}
Expand All @@ -106,6 +130,6 @@ public boolean shallowEquals(Node other)
return false;
}

return isDistinct == ((GroupBy) other).isDistinct;
return type.equals(((GroupBy) other).type);
}
}
Loading

0 comments on commit 0b7ccc1

Please sign in to comment.