diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java index 99672dcda0b8c..4fb49bf8f19dd 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java @@ -950,6 +950,25 @@ protected WriteMapping legacyToWriteMapping(@SuppressWarnings("unused") Connecto throw new TrinoException(NOT_SUPPORTED, "Unsupported column type: " + type.getDisplayName()); } + protected static boolean preventTextualTypeAggregationPushdown(List> groupingSets) + { + // Remote database can be case insensitive or sorts textual types differently than Trino. + // In such cases we should not pushdown aggregations if the grouping set contains a textual type. + if (!groupingSets.isEmpty()) { + for (List groupingSet : groupingSets) { + boolean hasCaseSensitiveGroupingSet = groupingSet.stream() + .map(columnHandle -> ((JdbcColumnHandle) columnHandle).getColumnType()) + // this may catch more cases than required (e.g. MONEY in Postgres) but doesn't affect correctness + .anyMatch(type -> type instanceof VarcharType || type instanceof CharType); + if (hasCaseSensitiveGroupingSet) { + return false; + } + } + } + + return true; + } + @Override public boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List sortOrder) { diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementMinMax.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementMinMax.java index 7614ca4724e57..adfc95a99e998 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementMinMax.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementMinMax.java @@ -20,6 +20,8 @@ import io.trino.plugin.jdbc.JdbcExpression; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; +import io.trino.spi.type.CharType; +import io.trino.spi.type.VarcharType; import java.util.Optional; import java.util.Set; @@ -40,6 +42,13 @@ public class ImplementMinMax { private static final Capture INPUT = newCapture(); + private final boolean isRemoteCollationSensitive; + + public ImplementMinMax(boolean isRemoteCollationSensitive) + { + this.isRemoteCollationSensitive = isRemoteCollationSensitive; + } + @Override public Pattern getPattern() { @@ -55,6 +64,11 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); verify(columnHandle.getColumnType().equals(aggregateFunction.getOutputType())); + // Remote database is case insensitive or sorts values differently from Trino + if (!isRemoteCollationSensitive && (columnHandle.getColumnType() instanceof CharType || columnHandle.getColumnType() instanceof VarcharType)) { + return Optional.empty(); + } + return Optional.of(new JdbcExpression( format("%s(%s)", aggregateFunction.getFunctionName(), context.getIdentifierQuote().apply(columnHandle.getColumnName())), columnHandle.getJdbcTypeHandle())); diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java index 369b690cf8260..211d34d5b2b6d 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java @@ -195,7 +195,7 @@ public void testAggregationPushdown() // GROUP BY above TopN assertConditionallyPushedDown( getSession(), - "SELECT clerk, sum(totalprice) FROM (SELECT clerk, totalprice FROM orders ORDER BY orderdate ASC, totalprice ASC LIMIT 10) GROUP BY clerk", + "SELECT custkey, sum(totalprice) FROM (SELECT custkey, totalprice FROM orders ORDER BY orderdate ASC, totalprice ASC LIMIT 10) GROUP BY custkey", hasBehavior(SUPPORTS_TOPN_PUSHDOWN), node(TopNNode.class, anyTree(node(TableScanNode.class)))); // GROUP BY with JOIN @@ -211,17 +211,100 @@ public void testAggregationPushdown() hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY), node(FilterNode.class, node(TableScanNode.class))); // aggregation on varchar column - assertThat(query("SELECT min(name) FROM nation")).isFullyPushedDown(); + assertThat(query("SELECT count(name) FROM nation")).isFullyPushedDown(); // aggregation on varchar column with GROUPING - assertThat(query("SELECT nationkey, min(name) FROM nation GROUP BY nationkey")).isFullyPushedDown(); + assertThat(query("SELECT nationkey, count(name) FROM nation GROUP BY nationkey")).isFullyPushedDown(); // aggregation on varchar column with WHERE assertConditionallyPushedDown( getSession(), - "SELECT min(name) FROM nation WHERE name = 'ARGENTINA'", + "SELECT count(name) FROM nation WHERE name = 'ARGENTINA'", hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY), node(FilterNode.class, node(TableScanNode.class))); } + @Test + public void testCaseSensitiveAggregationPushdown() + { + if (!hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN)) { + // Covered by testAggregationPushdown + return; + } + + boolean expectAggregationPushdown = hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY); + PlanMatchPattern aggregationOverTableScan = node(AggregationNode.class, node(TableScanNode.class)); + PlanMatchPattern groupingAggregationOverTableScan = node(AggregationNode.class, node(ProjectNode.class, node(TableScanNode.class))); + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "test_case_sensitive_aggregation_pushdown", + "(a_string varchar(1), a_char char(1), a_bigint bigint)", + ImmutableList.of( + "'A', 'A', 1", + "'B', 'B', 2", + "'a', 'a', 3", + "'b', 'b', 4"))) { + // case-sensitive functions prevent pushdown + assertConditionallyPushedDown( + getSession(), + "SELECT max(a_string), min(a_string), max(a_char), min(a_char) FROM " + table.getName(), + expectAggregationPushdown, + aggregationOverTableScan) + .skippingTypesCheck() + .matches("VALUES ('b', 'A', 'b', 'A')"); + // distinct over case-sensitive column prevents pushdown + assertConditionallyPushedDown( + getSession(), + "SELECT distinct a_string FROM " + table.getName(), + expectAggregationPushdown, + groupingAggregationOverTableScan) + .skippingTypesCheck() + .matches("VALUES 'A', 'B', 'a', 'b'"); + assertConditionallyPushedDown( + getSession(), + "SELECT distinct a_char FROM " + table.getName(), + expectAggregationPushdown, + groupingAggregationOverTableScan) + .skippingTypesCheck() + .matches("VALUES 'A', 'B', 'a', 'b'"); + // case-sensitive grouping sets prevent pushdown + assertConditionallyPushedDown(getSession(), + "SELECT a_string, count(*) FROM " + table.getName() + " GROUP BY a_string", + expectAggregationPushdown, + groupingAggregationOverTableScan) + .skippingTypesCheck() + .matches("VALUES ('A', BIGINT '1'), ('a', BIGINT '1'), ('b', BIGINT '1'), ('B', BIGINT '1')"); + assertConditionallyPushedDown(getSession(), + "SELECT a_char, count(*) FROM " + table.getName() + " GROUP BY a_char", + expectAggregationPushdown, + groupingAggregationOverTableScan) + .skippingTypesCheck() + .matches("VALUES ('A', BIGINT '1'), ('B', BIGINT '1'), ('a', BIGINT '1'), ('b', BIGINT '1')"); + + // case-insensitive functions can still be pushed down as long as grouping sets are not case-sensitive + assertThat(query("SELECT count(a_string), count(a_char) FROM " + table.getName())).isFullyPushedDown(); + assertThat(query("SELECT count(a_string), count(a_char) FROM " + table.getName() + " GROUP BY a_bigint")).isFullyPushedDown(); + + // DISTINCT over case-sensitive columns prevents pushdown + assertConditionallyPushedDown(getSession(), + "SELECT count(DISTINCT a_string) FROM " + table.getName(), + expectAggregationPushdown, + groupingAggregationOverTableScan) + .skippingTypesCheck() + .matches("VALUES BIGINT '4'"); + assertConditionallyPushedDown(getSession(), + "SELECT count(DISTINCT a_char) FROM " + table.getName(), + expectAggregationPushdown, + groupingAggregationOverTableScan) + .skippingTypesCheck() + .matches("VALUES BIGINT '4'"); + + // TODO: multiple count(DISTINCT expr) cannot be pushed down until https://github.com/trinodb/trino/pull/8562 gets merged + assertThat(query("SELECT count(DISTINCT a_string), count(DISTINCT a_bigint) FROM " + table.getName())) + .isNotFullyPushedDown(MarkDistinctNode.class, ExchangeNode.class, ExchangeNode.class, ProjectNode.class); + assertThat(query("SELECT count(DISTINCT a_char), count(DISTINCT a_bigint) FROM " + table.getName())) + .isNotFullyPushedDown(MarkDistinctNode.class, ExchangeNode.class, ExchangeNode.class, ProjectNode.class); + } + } + @Test public void testAggregationWithUnsupportedResultType() { @@ -265,7 +348,11 @@ public void testDistinctAggregationPushdown() // distinct aggregation with GROUP BY assertThat(query(withMarkDistinct, "SELECT count(DISTINCT nationkey) FROM nation GROUP BY regionkey")).isFullyPushedDown(); // distinct aggregation with varchar - assertThat(query(withMarkDistinct, "SELECT count(DISTINCT comment) FROM nation")).isFullyPushedDown(); + assertConditionallyPushedDown( + withMarkDistinct, + "SELECT count(DISTINCT comment) FROM nation", + hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY), + node(AggregationNode.class, node(ProjectNode.class, node(TableScanNode.class)))); // two distinct aggregations assertThat(query(withMarkDistinct, "SELECT count(DISTINCT regionkey), count(DISTINCT nationkey) FROM nation")) .isNotFullyPushedDown(MarkDistinctNode.class, ExchangeNode.class, ExchangeNode.class, ProjectNode.class); @@ -281,7 +368,11 @@ public void testDistinctAggregationPushdown() // distinct aggregation with GROUP BY assertThat(query(withoutMarkDistinct, "SELECT count(DISTINCT nationkey) FROM nation GROUP BY regionkey")).isFullyPushedDown(); // distinct aggregation with varchar - assertThat(query(withoutMarkDistinct, "SELECT count(DISTINCT comment) FROM nation")).isFullyPushedDown(); + assertConditionallyPushedDown( + withoutMarkDistinct, + "SELECT count(DISTINCT comment) FROM nation", + hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY), + node(AggregationNode.class, node(ProjectNode.class, node(TableScanNode.class)))); // two distinct aggregations assertThat(query(withoutMarkDistinct, "SELECT count(DISTINCT regionkey), count(DISTINCT nationkey) FROM nation")) .isNotFullyPushedDown(AggregationNode.class, ExchangeNode.class, ExchangeNode.class); @@ -583,7 +674,7 @@ public void testLimitPushdown() aggregationOverTableScan); assertConditionallyPushedDown( getSession(), - "SELECT regionkey, max(name) FROM nation GROUP BY regionkey LIMIT 5", + "SELECT regionkey, max(nationkey) FROM nation GROUP BY regionkey LIMIT 5", hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN), aggregationOverTableScan); @@ -1042,7 +1133,7 @@ public void testJoinPushdown() } } - private void assertConditionallyPushedDown( + private QueryAssert assertConditionallyPushedDown( Session session, @Language("SQL") String query, boolean condition, @@ -1050,10 +1141,10 @@ private void assertConditionallyPushedDown( { QueryAssert queryAssert = assertThat(query(session, query)); if (condition) { - queryAssert.isFullyPushedDown(); + return queryAssert.isFullyPushedDown(); } else { - queryAssert.isNotFullyPushedDown(otherwiseExpected); + return queryAssert.isNotFullyPushedDown(otherwiseExpected); } } diff --git a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java index f61e8c260df69..cc23b3ff97ed2 100644 --- a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java +++ b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java @@ -144,7 +144,7 @@ public ClickHouseClient( ImmutableSet.builder() .add(new ImplementCountAll(bigintTypeHandle)) .add(new ImplementCount(bigintTypeHandle)) - .add(new ImplementMinMax()) + .add(new ImplementMinMax(false)) // TODO: Revisit once https://github.com/trinodb/trino/issues/7100 is resolved .add(new ImplementSum(ClickHouseClient::toTypeHandle)) .add(new ImplementAvgFloatingPoint()) .add(new ImplementAvgDecimal()) @@ -159,6 +159,13 @@ public Optional implementAggregation(ConnectorSession session, A return aggregateFunctionRewriter.rewrite(session, aggregate, assignments); } + @Override + public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List aggregates, Map assignments, List> groupingSets) + { + // TODO: Remove override once https://github.com/trinodb/trino/issues/7100 is resolved. Currently pushdown for textual types is not tested and may lead to incorrect results. + return preventTextualTypeAggregationPushdown(groupingSets); + } + private static Optional toTypeHandle(DecimalType decimalType) { return Optional.of(new JdbcTypeHandle(Types.DECIMAL, Optional.of("Decimal"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty())); diff --git a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java index 8ea939e3d3498..35c7ad5e4c98c 100644 --- a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java +++ b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java @@ -157,7 +157,7 @@ public MySqlClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, T ImmutableSet.builder() .add(new ImplementCountAll(bigintTypeHandle)) .add(new ImplementCount(bigintTypeHandle)) - .add(new ImplementMinMax()) + .add(new ImplementMinMax(false)) .add(new ImplementSum(MySqlClient::toTypeHandle)) .add(new ImplementAvgFloatingPoint()) .add(new ImplementAvgDecimal()) @@ -176,6 +176,13 @@ public Optional implementAggregation(ConnectorSession session, A return aggregateFunctionRewriter.rewrite(session, aggregate, assignments); } + @Override + public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List aggregates, Map assignments, List> groupingSets) + { + // Remote database can be case insensitive. + return preventTextualTypeAggregationPushdown(groupingSets); + } + private static Optional toTypeHandle(DecimalType decimalType) { return Optional.of(new JdbcTypeHandle(Types.NUMERIC, Optional.of("decimal"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty())); diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index 30b6f44fda973..c1ffb097169da 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -283,7 +283,7 @@ public PostgreSqlClient( ImmutableSet.builder() .add(new ImplementCountAll(bigintTypeHandle)) .add(new ImplementCount(bigintTypeHandle)) - .add(new ImplementMinMax()) + .add(new ImplementMinMax(false)) .add(new ImplementSum(PostgreSqlClient::toTypeHandle)) .add(new ImplementAvgFloatingPoint()) .add(new ImplementAvgDecimal()) @@ -695,6 +695,13 @@ public Optional implementAggregation(ConnectorSession session, A return aggregateFunctionRewriter.rewrite(session, aggregate, assignments); } + @Override + public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List aggregates, Map assignments, List> groupingSets) + { + // Postgres sorts textual types differently compared to Trino so we cannot safely pushdown any aggregations which take a text type as an input or as part of grouping set + return preventTextualTypeAggregationPushdown(groupingSets); + } + private static Optional toTypeHandle(DecimalType decimalType) { return Optional.of(new JdbcTypeHandle(Types.NUMERIC, Optional.of("decimal"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty())); diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java index d3228e7629165..c9a25ff55fc8b 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java @@ -169,7 +169,7 @@ public SqlServerClient(BaseJdbcConfig config, SqlServerConfig sqlServerConfig, C ImmutableSet.builder() .add(new ImplementCountAll(bigintTypeHandle)) .add(new ImplementCount(bigintTypeHandle)) - .add(new ImplementMinMax()) + .add(new ImplementMinMax(false)) .add(new ImplementSum(SqlServerClient::toTypeHandle)) .add(new ImplementAvgFloatingPoint()) .add(new ImplementAvgDecimal()) @@ -424,6 +424,13 @@ public Optional implementAggregation(ConnectorSession session, A return aggregateFunctionRewriter.rewrite(session, aggregate, assignments); } + @Override + public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List aggregates, Map assignments, List> groupingSets) + { + // Remote database can be case insensitive. + return preventTextualTypeAggregationPushdown(groupingSets); + } + private static Optional toTypeHandle(DecimalType decimalType) { return Optional.of(new JdbcTypeHandle(Types.NUMERIC, Optional.of("decimal"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty()));