Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix incorrect results for aggregation functions on case-sensitive types #8551

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -950,6 +950,25 @@ protected WriteMapping legacyToWriteMapping(@SuppressWarnings("unused") Connecto
throw new TrinoException(NOT_SUPPORTED, "Unsupported column type: " + type.getDisplayName());
}

/**
 * Decides whether an aggregation with the given grouping sets is safe to push down to a
 * remote database whose textual comparison semantics may differ from Trino's.
 * <p>
 * The remote database can be case insensitive or may sort textual types differently than
 * Trino. In such cases aggregations must not be pushed down when any grouping set contains a
 * textual column, because grouping remotely could merge values Trino considers distinct.
 *
 * @param groupingSets the aggregation's grouping sets; a global aggregation is represented
 *        by a single empty grouping set
 * @return {@code true} if no grouping set contains a char/varchar column and pushdown is safe
 */
protected static boolean preventTextualTypeAggregationPushdown(List<List<ColumnHandle>> groupingSets)
{
    // An empty groupingSets list trivially yields true, so no explicit isEmpty() guard is needed.
    return groupingSets.stream()
            .flatMap(List::stream)
            .map(columnHandle -> ((JdbcColumnHandle) columnHandle).getColumnType())
            // this may catch more cases than required (e.g. MONEY in Postgres, which maps to
            // varchar) but erring on the side of caution doesn't affect correctness
            .noneMatch(type -> type instanceof VarcharType || type instanceof CharType);
}

@Override
public boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List<JdbcSortItem> sortOrder)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,9 @@ public WriteMapping toWriteMapping(ConnectorSession session, Type type)
}

@Override
public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List<List<ColumnHandle>> groupingSets)
public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List<AggregateFunction> aggregates, Map<String, ColumnHandle> assignments, List<List<ColumnHandle>> groupingSets)
{
return delegate.supportsAggregationPushdown(session, table, groupingSets);
return delegate.supportsAggregationPushdown(session, table, aggregates, assignments, groupingSets);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
// Global aggregation is represented by [[]]
verify(!groupingSets.isEmpty(), "No grouping sets provided");

if (!jdbcClient.supportsAggregationPushdown(session, handle, groupingSets)) {
if (!jdbcClient.supportsAggregationPushdown(session, handle, aggregates, assignments, groupingSets)) {
// JDBC client implementation prevents pushdown for the given table
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ public WriteMapping toWriteMapping(ConnectorSession session, Type type)
}

@Override
public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List<List<ColumnHandle>> groupingSets)
public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List<AggregateFunction> aggregates, Map<String, ColumnHandle> assignments, List<List<ColumnHandle>> groupingSets)
{
return delegate().supportsAggregationPushdown(session, table, groupingSets);
return delegate().supportsAggregationPushdown(session, table, aggregates, assignments, groupingSets);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ default boolean schemaExists(ConnectorSession session, String schema)

WriteMapping toWriteMapping(ConnectorSession session, Type type);

default boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List<List<ColumnHandle>> groupingSets)
default boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List<AggregateFunction> aggregates, Map<String, ColumnHandle> assignments, List<List<ColumnHandle>> groupingSets)
{
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.CharType;
import io.trino.spi.type.VarcharType;

import java.util.Optional;
import java.util.Set;
Expand All @@ -40,6 +42,13 @@ public class ImplementMinMax
{
private static final Capture<Variable> INPUT = newCapture();

private final boolean isRemoteCollationSensitive;

/**
 * @param isRemoteCollationSensitive whether the remote database's collation (case sensitivity
 *        and sort order of textual types) is known to match Trino's; when {@code false} the
 *        rule declines to rewrite min/max over char/varchar columns
 */
public ImplementMinMax(boolean isRemoteCollationSensitive)
{
    this.isRemoteCollationSensitive = isRemoteCollationSensitive;
}

@Override
public Pattern<AggregateFunction> getPattern()
{
Expand All @@ -55,6 +64,11 @@ public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Cap
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
verify(columnHandle.getColumnType().equals(aggregateFunction.getOutputType()));

// Remote database is case insensitive or sorts values differently from Trino
if (!isRemoteCollationSensitive && (columnHandle.getColumnType() instanceof CharType || columnHandle.getColumnType() instanceof VarcharType)) {
return Optional.empty();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's based on a wrong type.

For example, PostgreSQL's money and enum types are mapped to Trino varchar, while both will have different sorting properties.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the rewrite rule is generic I think the only solution is to allow connectors to pass a list of jdbcTypeName through the constructor for types which should not be pushed down?

Or maybe add a static function to JdbcClient called isCollationSensitive(JdbcColumnHandle) and let connectors define their own impls? This method already exists in the PostgreSQL client btw and may be useful for others where we can pass explicit collations (MySQL once we drop 5.x).

Copy link
Member

@findepi findepi Jul 27, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think we already use "is mapped to char or varchar" as a way to determine whether a type is potentially case insensitive (or collation-sensitive).
It may catch too much, but it shouldn't catch too little, so it's fine.

leave as is

}

return Optional.of(new JdbcExpression(
format("%s(%s)", aggregateFunction.getFunctionName(), context.getIdentifierQuote().apply(columnHandle.getColumnName())),
columnHandle.getJdbcTypeHandle()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ public WriteMapping toWriteMapping(ConnectorSession session, Type type)
}

@Override
public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List<List<ColumnHandle>> groupingSets)
public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List<AggregateFunction> aggregates, Map<String, ColumnHandle> assignments, List<List<ColumnHandle>> groupingSets)
{
return delegate().supportsAggregationPushdown(session, table, groupingSets);
return delegate().supportsAggregationPushdown(session, table, aggregates, assignments, groupingSets);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ public void testAggregationPushdown()
// GROUP BY above TopN
assertConditionallyPushedDown(
getSession(),
"SELECT clerk, sum(totalprice) FROM (SELECT clerk, totalprice FROM orders ORDER BY orderdate ASC, totalprice ASC LIMIT 10) GROUP BY clerk",
"SELECT custkey, sum(totalprice) FROM (SELECT custkey, totalprice FROM orders ORDER BY orderdate ASC, totalprice ASC LIMIT 10) GROUP BY custkey",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why change here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clerk is a varchar and part of the grouping set, so it prevents pushdown after the change. Changed to use a numeric column since we just want to test GROUP BY + TopN.

hasBehavior(SUPPORTS_TOPN_PUSHDOWN),
node(TopNNode.class, anyTree(node(TableScanNode.class))));
// GROUP BY with JOIN
Expand All @@ -211,17 +211,100 @@ public void testAggregationPushdown()
hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY),
node(FilterNode.class, node(TableScanNode.class)));
// aggregation on varchar column
assertThat(query("SELECT min(name) FROM nation")).isFullyPushedDown();
assertThat(query("SELECT count(name) FROM nation")).isFullyPushedDown();
// aggregation on varchar column with GROUPING
assertThat(query("SELECT nationkey, min(name) FROM nation GROUP BY nationkey")).isFullyPushedDown();
assertThat(query("SELECT nationkey, count(name) FROM nation GROUP BY nationkey")).isFullyPushedDown();
// aggregation on varchar column with WHERE
assertConditionallyPushedDown(
getSession(),
"SELECT min(name) FROM nation WHERE name = 'ARGENTINA'",
"SELECT count(name) FROM nation WHERE name = 'ARGENTINA'",
hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY),
node(FilterNode.class, node(TableScanNode.class)));
}

@Test
public void testCaseSensitiveAggregationPushdown()
{
if (!hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN)) {
// Covered by testAggregationPushdown
return;
}

boolean expectAggregationPushdown = hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY);
PlanMatchPattern aggregationOverTableScan = node(AggregationNode.class, node(TableScanNode.class));
PlanMatchPattern groupingAggregationOverTableScan = node(AggregationNode.class, node(ProjectNode.class, node(TableScanNode.class)));
try (TestTable table = new TestTable(
getQueryRunner()::execute,
"test_case_sensitive_aggregation_pushdown",
"(a_string varchar(1), a_char char(1), a_bigint bigint)",
ImmutableList.of(
"'A', 'A', 1",
"'B', 'B', 2",
"'a', 'a', 3",
"'b', 'b', 4"))) {
// case-sensitive functions prevent pushdown
assertConditionallyPushedDown(
getSession(),
"SELECT max(a_string), min(a_string), max(a_char), min(a_char) FROM " + table.getName(),
expectAggregationPushdown,
aggregationOverTableScan)
.skippingTypesCheck()
.matches("VALUES ('b', 'A', 'b', 'A')");
// distinct over case-sensitive column prevents pushdown
assertConditionallyPushedDown(
getSession(),
"SELECT distinct a_string FROM " + table.getName(),
expectAggregationPushdown,
groupingAggregationOverTableScan)
.skippingTypesCheck()
.matches("VALUES 'A', 'B', 'a', 'b'");
assertConditionallyPushedDown(
getSession(),
"SELECT distinct a_char FROM " + table.getName(),
expectAggregationPushdown,
groupingAggregationOverTableScan)
.skippingTypesCheck()
.matches("VALUES 'A', 'B', 'a', 'b'");
// case-sensitive grouping sets prevent pushdown
assertConditionallyPushedDown(getSession(),
"SELECT a_string, count(*) FROM " + table.getName() + " GROUP BY a_string",
expectAggregationPushdown,
groupingAggregationOverTableScan)
.skippingTypesCheck()
.matches("VALUES ('A', BIGINT '1'), ('a', BIGINT '1'), ('b', BIGINT '1'), ('B', BIGINT '1')");
assertConditionallyPushedDown(getSession(),
"SELECT a_char, count(*) FROM " + table.getName() + " GROUP BY a_char",
expectAggregationPushdown,
groupingAggregationOverTableScan)
.skippingTypesCheck()
.matches("VALUES ('A', BIGINT '1'), ('B', BIGINT '1'), ('a', BIGINT '1'), ('b', BIGINT '1')");

// case-insensitive functions can still be pushed down as long as grouping sets are not case-sensitive
assertThat(query("SELECT count(a_string), count(a_char) FROM " + table.getName())).isFullyPushedDown();
assertThat(query("SELECT count(a_string), count(a_char) FROM " + table.getName() + " GROUP BY a_bigint")).isFullyPushedDown();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's add a correctness cases

  • count(DISTINCT a_string)
  • count(DISTINCT a_string), count(DISTINCT a_bigint)` (together)

this could help avoid any regressions in #8562 cc @alexjo2144


// DISTINCT over case-sensitive columns prevents pushdown
assertConditionallyPushedDown(getSession(),
"SELECT count(DISTINCT a_string) FROM " + table.getName(),
expectAggregationPushdown,
groupingAggregationOverTableScan)
.skippingTypesCheck()
.matches("VALUES BIGINT '4'");
assertConditionallyPushedDown(getSession(),
"SELECT count(DISTINCT a_char) FROM " + table.getName(),
expectAggregationPushdown,
groupingAggregationOverTableScan)
.skippingTypesCheck()
.matches("VALUES BIGINT '4'");

// TODO: multiple count(DISTINCT expr) cannot be pushed down until https://github.com/trinodb/trino/pull/8562 gets merged
assertThat(query("SELECT count(DISTINCT a_string), count(DISTINCT a_bigint) FROM " + table.getName()))
.isNotFullyPushedDown(MarkDistinctNode.class, ExchangeNode.class, ExchangeNode.class, ProjectNode.class);
assertThat(query("SELECT count(DISTINCT a_char), count(DISTINCT a_bigint) FROM " + table.getName()))
.isNotFullyPushedDown(MarkDistinctNode.class, ExchangeNode.class, ExchangeNode.class, ProjectNode.class);
}
}

@Test
public void testAggregationWithUnsupportedResultType()
{
Expand Down Expand Up @@ -265,7 +348,11 @@ public void testDistinctAggregationPushdown()
// distinct aggregation with GROUP BY
assertThat(query(withMarkDistinct, "SELECT count(DISTINCT nationkey) FROM nation GROUP BY regionkey")).isFullyPushedDown();
// distinct aggregation with varchar
assertThat(query(withMarkDistinct, "SELECT count(DISTINCT comment) FROM nation")).isFullyPushedDown();
assertConditionallyPushedDown(
withMarkDistinct,
"SELECT count(DISTINCT comment) FROM nation",
hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY),
node(AggregationNode.class, node(ProjectNode.class, node(TableScanNode.class))));
// two distinct aggregations
assertThat(query(withMarkDistinct, "SELECT count(DISTINCT regionkey), count(DISTINCT nationkey) FROM nation"))
.isNotFullyPushedDown(MarkDistinctNode.class, ExchangeNode.class, ExchangeNode.class, ProjectNode.class);
Expand All @@ -281,7 +368,11 @@ public void testDistinctAggregationPushdown()
// distinct aggregation with GROUP BY
assertThat(query(withoutMarkDistinct, "SELECT count(DISTINCT nationkey) FROM nation GROUP BY regionkey")).isFullyPushedDown();
// distinct aggregation with varchar
assertThat(query(withoutMarkDistinct, "SELECT count(DISTINCT comment) FROM nation")).isFullyPushedDown();
assertConditionallyPushedDown(
withoutMarkDistinct,
"SELECT count(DISTINCT comment) FROM nation",
hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY),
node(AggregationNode.class, node(ProjectNode.class, node(TableScanNode.class))));
// two distinct aggregations
assertThat(query(withoutMarkDistinct, "SELECT count(DISTINCT regionkey), count(DISTINCT nationkey) FROM nation"))
.isNotFullyPushedDown(AggregationNode.class, ExchangeNode.class, ExchangeNode.class);
Expand Down Expand Up @@ -583,7 +674,7 @@ public void testLimitPushdown()
aggregationOverTableScan);
assertConditionallyPushedDown(
getSession(),
"SELECT regionkey, max(name) FROM nation GROUP BY regionkey LIMIT 5",
"SELECT regionkey, max(nationkey) FROM nation GROUP BY regionkey LIMIT 5",
hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN),
aggregationOverTableScan);

Expand Down Expand Up @@ -1042,18 +1133,18 @@ public void testJoinPushdown()
}
}

private void assertConditionallyPushedDown(
private QueryAssert assertConditionallyPushedDown(
Session session,
@Language("SQL") String query,
boolean condition,
PlanMatchPattern otherwiseExpected)
{
QueryAssert queryAssert = assertThat(query(session, query));
if (condition) {
queryAssert.isFullyPushedDown();
return queryAssert.isFullyPushedDown();
}
else {
queryAssert.isNotFullyPushedDown(otherwiseExpected);
return queryAssert.isNotFullyPushedDown(otherwiseExpected);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ protected JdbcClient delegate()
}

@Override
public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List<List<ColumnHandle>> groupingSets)
public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List<AggregateFunction> aggregates, Map<String, ColumnHandle> assignments, List<List<ColumnHandle>> groupingSets)
{
// disable aggregation pushdown for any table named no_agg_pushdown
return !"no_aggregation_pushdown".equalsIgnoreCase(table.getRequiredNamedRelation().getRemoteTableName().getTableName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public TestingH2JdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFa
}

@Override
public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List<List<ColumnHandle>> groupingSets)
public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List<AggregateFunction> aggregates, Map<String, ColumnHandle> assignments, List<List<ColumnHandle>> groupingSets)
{
// GROUP BY with GROUPING SETS is not supported
return groupingSets.size() == 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ public ClickHouseClient(
ImmutableSet.<AggregateFunctionRule>builder()
.add(new ImplementCountAll(bigintTypeHandle))
.add(new ImplementCount(bigintTypeHandle))
.add(new ImplementMinMax())
.add(new ImplementMinMax(false)) // TODO: Revisit once https://github.com/trinodb/trino/issues/7100 is resolved
.add(new ImplementSum(ClickHouseClient::toTypeHandle))
.add(new ImplementAvgFloatingPoint())
.add(new ImplementAvgDecimal())
Expand All @@ -159,6 +159,13 @@ public Optional<JdbcExpression> implementAggregation(ConnectorSession session, A
return aggregateFunctionRewriter.rewrite(session, aggregate, assignments);
}

@Override
public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List<AggregateFunction> aggregates, Map<String, ColumnHandle> assignments, List<List<ColumnHandle>> groupingSets)
{
// TODO: Remove override once https://github.com/trinodb/trino/issues/7100 is resolved. Currently pushdown for textual types is not tested and may lead to incorrect results.
findepi marked this conversation as resolved.
Show resolved Hide resolved
return preventTextualTypeAggregationPushdown(groupingSets);
}

private static Optional<JdbcTypeHandle> toTypeHandle(DecimalType decimalType)
{
return Optional.of(new JdbcTypeHandle(Types.DECIMAL, Optional.of("Decimal"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ public MySqlClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, T
ImmutableSet.<AggregateFunctionRule>builder()
.add(new ImplementCountAll(bigintTypeHandle))
.add(new ImplementCount(bigintTypeHandle))
.add(new ImplementMinMax())
.add(new ImplementMinMax(false))
.add(new ImplementSum(MySqlClient::toTypeHandle))
.add(new ImplementAvgFloatingPoint())
.add(new ImplementAvgDecimal())
Expand All @@ -176,6 +176,13 @@ public Optional<JdbcExpression> implementAggregation(ConnectorSession session, A
return aggregateFunctionRewriter.rewrite(session, aggregate, assignments);
}

/**
 * Prevents aggregation pushdown when any grouping set contains a textual column, since the
 * remote database can be case insensitive and could merge groups Trino considers distinct.
 */
@Override
public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List<AggregateFunction> aggregates, Map<String, ColumnHandle> assignments, List<List<ColumnHandle>> groupingSets)
{
    // Remote database can be case insensitive.
    return preventTextualTypeAggregationPushdown(groupingSets);
}

private static Optional<JdbcTypeHandle> toTypeHandle(DecimalType decimalType)
{
return Optional.of(new JdbcTypeHandle(Types.NUMERIC, Optional.of("decimal"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty()));
Expand Down
Loading