Skip to content

Commit

Permalink
Prevent aggregation pushdown for textual types for some connectors
Browse files Browse the repository at this point in the history
Some databases are case-insensitive (MySQL, SQL Server) while others
sort textual types differently compared to Trino (PostgreSQL). For such
databases pushdown of aggregation functions when the grouping set
includes a textual type can lead to incorrect results. So we prevent
aggregation pushdown for such cases.
We also prevent pushdown for functions whose results depend on sort
order (min/max) when the input is a textual type.
  • Loading branch information
hashhar committed Aug 2, 2021
1 parent c0bb821 commit 6473c84
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -950,6 +950,25 @@ protected WriteMapping legacyToWriteMapping(@SuppressWarnings("unused") Connecto
throw new TrinoException(NOT_SUPPORTED, "Unsupported column type: " + type.getDisplayName());
}

protected static boolean preventTextualTypeAggregationPushdown(List<List<ColumnHandle>> groupingSets)
{
// Remote database can be case insensitive or sorts textual types differently than Trino.
// In such cases we should not pushdown aggregations if the grouping set contains a textual type.
if (!groupingSets.isEmpty()) {
for (List<ColumnHandle> groupingSet : groupingSets) {
boolean hasCaseSensitiveGroupingSet = groupingSet.stream()
.map(columnHandle -> ((JdbcColumnHandle) columnHandle).getColumnType())
// this may catch more cases than required (e.g. MONEY in Postgres) but doesn't affect correctness
.anyMatch(type -> type instanceof VarcharType || type instanceof CharType);
if (hasCaseSensitiveGroupingSet) {
return false;
}
}
}

return true;
}

@Override
public boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List<JdbcSortItem> sortOrder)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.CharType;
import io.trino.spi.type.VarcharType;

import java.util.Optional;
import java.util.Set;
Expand All @@ -40,6 +42,13 @@ public class ImplementMinMax
{
private static final Capture<Variable> INPUT = newCapture();

private final boolean isRemoteCollationSensitive;

public ImplementMinMax(boolean isRemoteCollationSensitive)
{
this.isRemoteCollationSensitive = isRemoteCollationSensitive;
}

@Override
public Pattern<AggregateFunction> getPattern()
{
Expand All @@ -55,6 +64,11 @@ public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Cap
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
verify(columnHandle.getColumnType().equals(aggregateFunction.getOutputType()));

// Remote database is case insensitive or sorts values differently from Trino
if (!isRemoteCollationSensitive && (columnHandle.getColumnType() instanceof CharType || columnHandle.getColumnType() instanceof VarcharType)) {
return Optional.empty();
}

return Optional.of(new JdbcExpression(
format("%s(%s)", aggregateFunction.getFunctionName(), context.getIdentifierQuote().apply(columnHandle.getColumnName())),
columnHandle.getJdbcTypeHandle()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ public void testAggregationPushdown()
// GROUP BY above TopN
assertConditionallyPushedDown(
getSession(),
"SELECT clerk, sum(totalprice) FROM (SELECT clerk, totalprice FROM orders ORDER BY orderdate ASC, totalprice ASC LIMIT 10) GROUP BY clerk",
"SELECT custkey, sum(totalprice) FROM (SELECT custkey, totalprice FROM orders ORDER BY orderdate ASC, totalprice ASC LIMIT 10) GROUP BY custkey",
hasBehavior(SUPPORTS_TOPN_PUSHDOWN),
node(TopNNode.class, anyTree(node(TableScanNode.class))));
// GROUP BY with JOIN
Expand All @@ -211,17 +211,100 @@ public void testAggregationPushdown()
hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY),
node(FilterNode.class, node(TableScanNode.class)));
// aggregation on varchar column
assertThat(query("SELECT min(name) FROM nation")).isFullyPushedDown();
assertThat(query("SELECT count(name) FROM nation")).isFullyPushedDown();
// aggregation on varchar column with GROUPING
assertThat(query("SELECT nationkey, min(name) FROM nation GROUP BY nationkey")).isFullyPushedDown();
assertThat(query("SELECT nationkey, count(name) FROM nation GROUP BY nationkey")).isFullyPushedDown();
// aggregation on varchar column with WHERE
assertConditionallyPushedDown(
getSession(),
"SELECT min(name) FROM nation WHERE name = 'ARGENTINA'",
"SELECT count(name) FROM nation WHERE name = 'ARGENTINA'",
hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY),
node(FilterNode.class, node(TableScanNode.class)));
}

@Test
public void testCaseSensitiveAggregationPushdown()
{
if (!hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN)) {
// Covered by testAggregationPushdown
return;
}

boolean expectAggregationPushdown = hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY);
PlanMatchPattern aggregationOverTableScan = node(AggregationNode.class, node(TableScanNode.class));
PlanMatchPattern groupingAggregationOverTableScan = node(AggregationNode.class, node(ProjectNode.class, node(TableScanNode.class)));
try (TestTable table = new TestTable(
getQueryRunner()::execute,
"test_case_sensitive_aggregation_pushdown",
"(a_string varchar(1), a_char char(1), a_bigint bigint)",
ImmutableList.of(
"'A', 'A', 1",
"'B', 'B', 2",
"'a', 'a', 3",
"'b', 'b', 4"))) {
// case-sensitive functions prevent pushdown
assertConditionallyPushedDown(
getSession(),
"SELECT max(a_string), min(a_string), max(a_char), min(a_char) FROM " + table.getName(),
expectAggregationPushdown,
aggregationOverTableScan)
.skippingTypesCheck()
.matches("VALUES ('b', 'A', 'b', 'A')");
// distinct over case-sensitive column prevents pushdown
assertConditionallyPushedDown(
getSession(),
"SELECT distinct a_string FROM " + table.getName(),
expectAggregationPushdown,
groupingAggregationOverTableScan)
.skippingTypesCheck()
.matches("VALUES 'A', 'B', 'a', 'b'");
assertConditionallyPushedDown(
getSession(),
"SELECT distinct a_char FROM " + table.getName(),
expectAggregationPushdown,
groupingAggregationOverTableScan)
.skippingTypesCheck()
.matches("VALUES 'A', 'B', 'a', 'b'");
// case-sensitive grouping sets prevent pushdown
assertConditionallyPushedDown(getSession(),
"SELECT a_string, count(*) FROM " + table.getName() + " GROUP BY a_string",
expectAggregationPushdown,
groupingAggregationOverTableScan)
.skippingTypesCheck()
.matches("VALUES ('A', BIGINT '1'), ('a', BIGINT '1'), ('b', BIGINT '1'), ('B', BIGINT '1')");
assertConditionallyPushedDown(getSession(),
"SELECT a_char, count(*) FROM " + table.getName() + " GROUP BY a_char",
expectAggregationPushdown,
groupingAggregationOverTableScan)
.skippingTypesCheck()
.matches("VALUES ('A', BIGINT '1'), ('B', BIGINT '1'), ('a', BIGINT '1'), ('b', BIGINT '1')");

// case-insensitive functions can still be pushed down as long as grouping sets are not case-sensitive
assertThat(query("SELECT count(a_string), count(a_char) FROM " + table.getName())).isFullyPushedDown();
assertThat(query("SELECT count(a_string), count(a_char) FROM " + table.getName() + " GROUP BY a_bigint")).isFullyPushedDown();

// DISTINCT over case-sensitive columns prevents pushdown
assertConditionallyPushedDown(getSession(),
"SELECT count(DISTINCT a_string) FROM " + table.getName(),
expectAggregationPushdown,
groupingAggregationOverTableScan)
.skippingTypesCheck()
.matches("VALUES BIGINT '4'");
assertConditionallyPushedDown(getSession(),
"SELECT count(DISTINCT a_char) FROM " + table.getName(),
expectAggregationPushdown,
groupingAggregationOverTableScan)
.skippingTypesCheck()
.matches("VALUES BIGINT '4'");

// TODO: multiple count(DISTINCT expr) cannot be pushed down until https://github.com/trinodb/trino/pull/8562 gets merged
assertThat(query("SELECT count(DISTINCT a_string), count(DISTINCT a_bigint) FROM " + table.getName()))
.isNotFullyPushedDown(MarkDistinctNode.class, ExchangeNode.class, ExchangeNode.class, ProjectNode.class);
assertThat(query("SELECT count(DISTINCT a_char), count(DISTINCT a_bigint) FROM " + table.getName()))
.isNotFullyPushedDown(MarkDistinctNode.class, ExchangeNode.class, ExchangeNode.class, ProjectNode.class);
}
}

@Test
public void testAggregationWithUnsupportedResultType()
{
Expand Down Expand Up @@ -265,7 +348,11 @@ public void testDistinctAggregationPushdown()
// distinct aggregation with GROUP BY
assertThat(query(withMarkDistinct, "SELECT count(DISTINCT nationkey) FROM nation GROUP BY regionkey")).isFullyPushedDown();
// distinct aggregation with varchar
assertThat(query(withMarkDistinct, "SELECT count(DISTINCT comment) FROM nation")).isFullyPushedDown();
assertConditionallyPushedDown(
withMarkDistinct,
"SELECT count(DISTINCT comment) FROM nation",
hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY),
node(AggregationNode.class, node(ProjectNode.class, node(TableScanNode.class))));
// two distinct aggregations
assertThat(query(withMarkDistinct, "SELECT count(DISTINCT regionkey), count(DISTINCT nationkey) FROM nation"))
.isNotFullyPushedDown(MarkDistinctNode.class, ExchangeNode.class, ExchangeNode.class, ProjectNode.class);
Expand All @@ -281,7 +368,11 @@ public void testDistinctAggregationPushdown()
// distinct aggregation with GROUP BY
assertThat(query(withoutMarkDistinct, "SELECT count(DISTINCT nationkey) FROM nation GROUP BY regionkey")).isFullyPushedDown();
// distinct aggregation with varchar
assertThat(query(withoutMarkDistinct, "SELECT count(DISTINCT comment) FROM nation")).isFullyPushedDown();
assertConditionallyPushedDown(
withoutMarkDistinct,
"SELECT count(DISTINCT comment) FROM nation",
hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY),
node(AggregationNode.class, node(ProjectNode.class, node(TableScanNode.class))));
// two distinct aggregations
assertThat(query(withoutMarkDistinct, "SELECT count(DISTINCT regionkey), count(DISTINCT nationkey) FROM nation"))
.isNotFullyPushedDown(AggregationNode.class, ExchangeNode.class, ExchangeNode.class);
Expand Down Expand Up @@ -583,7 +674,7 @@ public void testLimitPushdown()
aggregationOverTableScan);
assertConditionallyPushedDown(
getSession(),
"SELECT regionkey, max(name) FROM nation GROUP BY regionkey LIMIT 5",
"SELECT regionkey, max(nationkey) FROM nation GROUP BY regionkey LIMIT 5",
hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN),
aggregationOverTableScan);

Expand Down Expand Up @@ -1042,18 +1133,18 @@ public void testJoinPushdown()
}
}

private void assertConditionallyPushedDown(
private QueryAssert assertConditionallyPushedDown(
Session session,
@Language("SQL") String query,
boolean condition,
PlanMatchPattern otherwiseExpected)
{
QueryAssert queryAssert = assertThat(query(session, query));
if (condition) {
queryAssert.isFullyPushedDown();
return queryAssert.isFullyPushedDown();
}
else {
queryAssert.isNotFullyPushedDown(otherwiseExpected);
return queryAssert.isNotFullyPushedDown(otherwiseExpected);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ public ClickHouseClient(
ImmutableSet.<AggregateFunctionRule>builder()
.add(new ImplementCountAll(bigintTypeHandle))
.add(new ImplementCount(bigintTypeHandle))
.add(new ImplementMinMax())
.add(new ImplementMinMax(false)) // TODO: Revisit once https://github.com/trinodb/trino/issues/7100 is resolved
.add(new ImplementSum(ClickHouseClient::toTypeHandle))
.add(new ImplementAvgFloatingPoint())
.add(new ImplementAvgDecimal())
Expand All @@ -159,6 +159,13 @@ public Optional<JdbcExpression> implementAggregation(ConnectorSession session, A
return aggregateFunctionRewriter.rewrite(session, aggregate, assignments);
}

@Override
public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List<AggregateFunction> aggregates, Map<String, ColumnHandle> assignments, List<List<ColumnHandle>> groupingSets)
{
// TODO: Remove override once https://github.com/trinodb/trino/issues/7100 is resolved. Currently pushdown for textual types is not tested and may lead to incorrect results.
return preventTextualTypeAggregationPushdown(groupingSets);
}

private static Optional<JdbcTypeHandle> toTypeHandle(DecimalType decimalType)
{
return Optional.of(new JdbcTypeHandle(Types.DECIMAL, Optional.of("Decimal"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ public MySqlClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, T
ImmutableSet.<AggregateFunctionRule>builder()
.add(new ImplementCountAll(bigintTypeHandle))
.add(new ImplementCount(bigintTypeHandle))
.add(new ImplementMinMax())
.add(new ImplementMinMax(false))
.add(new ImplementSum(MySqlClient::toTypeHandle))
.add(new ImplementAvgFloatingPoint())
.add(new ImplementAvgDecimal())
Expand All @@ -176,6 +176,13 @@ public Optional<JdbcExpression> implementAggregation(ConnectorSession session, A
return aggregateFunctionRewriter.rewrite(session, aggregate, assignments);
}

@Override
public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List<AggregateFunction> aggregates, Map<String, ColumnHandle> assignments, List<List<ColumnHandle>> groupingSets)
{
// Remote database can be case insensitive.
return preventTextualTypeAggregationPushdown(groupingSets);
}

private static Optional<JdbcTypeHandle> toTypeHandle(DecimalType decimalType)
{
return Optional.of(new JdbcTypeHandle(Types.NUMERIC, Optional.of("decimal"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ public PostgreSqlClient(
ImmutableSet.<AggregateFunctionRule>builder()
.add(new ImplementCountAll(bigintTypeHandle))
.add(new ImplementCount(bigintTypeHandle))
.add(new ImplementMinMax())
.add(new ImplementMinMax(false))
.add(new ImplementSum(PostgreSqlClient::toTypeHandle))
.add(new ImplementAvgFloatingPoint())
.add(new ImplementAvgDecimal())
Expand Down Expand Up @@ -695,6 +695,13 @@ public Optional<JdbcExpression> implementAggregation(ConnectorSession session, A
return aggregateFunctionRewriter.rewrite(session, aggregate, assignments);
}

@Override
public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List<AggregateFunction> aggregates, Map<String, ColumnHandle> assignments, List<List<ColumnHandle>> groupingSets)
{
// Postgres sorts textual types differently compared to Trino so we cannot safely pushdown any aggregations which take a text type as an input or as part of grouping set
return preventTextualTypeAggregationPushdown(groupingSets);
}

private static Optional<JdbcTypeHandle> toTypeHandle(DecimalType decimalType)
{
return Optional.of(new JdbcTypeHandle(Types.NUMERIC, Optional.of("decimal"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ public SqlServerClient(BaseJdbcConfig config, SqlServerConfig sqlServerConfig, C
ImmutableSet.<AggregateFunctionRule>builder()
.add(new ImplementCountAll(bigintTypeHandle))
.add(new ImplementCount(bigintTypeHandle))
.add(new ImplementMinMax())
.add(new ImplementMinMax(false))
.add(new ImplementSum(SqlServerClient::toTypeHandle))
.add(new ImplementAvgFloatingPoint())
.add(new ImplementAvgDecimal())
Expand Down Expand Up @@ -424,6 +424,13 @@ public Optional<JdbcExpression> implementAggregation(ConnectorSession session, A
return aggregateFunctionRewriter.rewrite(session, aggregate, assignments);
}

@Override
public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List<AggregateFunction> aggregates, Map<String, ColumnHandle> assignments, List<List<ColumnHandle>> groupingSets)
{
// Remote database can be case insensitive.
return preventTextualTypeAggregationPushdown(groupingSets);
}

private static Optional<JdbcTypeHandle> toTypeHandle(DecimalType decimalType)
{
return Optional.of(new JdbcTypeHandle(Types.NUMERIC, Optional.of("decimal"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty()));
Expand Down

0 comments on commit 6473c84

Please sign in to comment.