diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java index 410ab848f1228..de5c919fca0df 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java @@ -485,7 +485,6 @@ public PlanOptimizers( new RemoveRedundantExists(), new ImplementFilteredAggregations(metadata), new SingleDistinctAggregationToGroupBy(), - new MultipleDistinctAggregationToMarkDistinct(), new MergeLimitWithDistinct(), new PruneCountAggregationOverScalar(metadata), new PruneOrderByInAggregation(metadata), @@ -675,7 +674,8 @@ public PlanOptimizers( new RemoveEmptyExceptBranches(), new RemoveRedundantIdentityProjections(), new PushAggregationThroughOuterJoin(), - new ReplaceRedundantJoinWithSource())), // Run this after PredicatePushDown optimizer as it inlines filter constants + new ReplaceRedundantJoinWithSource(), // Run this after PredicatePushDown optimizer as it inlines filter constants + new MultipleDistinctAggregationToMarkDistinct())), // Run this after aggregation pushdown so that multiple distinct aggregations can be pushed into a connector inlineProjections, simplifyOptimizer, // Re-run the SimplifyExpressions to simplify any recomposed expressions from other optimizations pushProjectionIntoTableScanOptimizer, diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementCountDistinct.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementCountDistinct.java new file mode 100644 index 0000000000000..9424346a8f983 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementCountDistinct.java @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.base.expression.AggregateFunctionRule; +import io.trino.plugin.jdbc.JdbcClient; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.spi.connector.AggregateFunction; +import io.trino.spi.expression.Variable; +import io.trino.spi.type.BigintType; +import io.trino.spi.type.CharType; +import io.trino.spi.type.VarcharType; + +import java.util.Optional; + +import static com.google.common.base.Verify.verify; +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.base.expression.AggregateFunctionPatterns.distinct; +import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName; +import static io.trino.plugin.base.expression.AggregateFunctionPatterns.hasFilter; +import static io.trino.plugin.base.expression.AggregateFunctionPatterns.singleInput; +import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variable; +import static io.trino.spi.type.BigintType.BIGINT; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +/** + * Implements {@code count(DISTINCT x)}. + */ +public class ImplementCountDistinct + implements AggregateFunctionRule +{ + private static final Capture INPUT = newCapture(); + + private final JdbcTypeHandle bigintTypeHandle; + private final boolean isRemoteCollationSensitive; + + /** + * @param bigintTypeHandle A {@link JdbcTypeHandle} that will be mapped to {@link BigintType} by {@link JdbcClient#toColumnMapping}. + */ + public ImplementCountDistinct(JdbcTypeHandle bigintTypeHandle, boolean isRemoteCollationSensitive) + { + this.bigintTypeHandle = requireNonNull(bigintTypeHandle, "bigintTypeHandle is null"); + this.isRemoteCollationSensitive = isRemoteCollationSensitive; + } + + @Override + public Pattern getPattern() + { + return Pattern.typeOf(AggregateFunction.class) + .with(distinct().equalTo(true)) + .with(hasFilter().equalTo(false)) + .with(functionName().equalTo("count")) + .with(singleInput().matching(variable().capturedAs(INPUT))); + } + + @Override + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + { + Variable input = captures.get(INPUT); + JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); + verify(aggregateFunction.getOutputType() == BIGINT); + + boolean isCaseSensitiveType = columnHandle.getColumnType() instanceof CharType || columnHandle.getColumnType() instanceof VarcharType; + if (aggregateFunction.isDistinct() && !isRemoteCollationSensitive && isCaseSensitiveType) { + // Remote database is case insensitive or compares values differently from Trino + return Optional.empty(); + } + + return Optional.of(new JdbcExpression( + format("count(DISTINCT %s)", context.getIdentifierQuote().apply(columnHandle.getColumnName())), + bigintTypeHandle)); + } +} diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java index 91eaebff2aa17..07d5265a5ee30 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java @@ -64,12 +64,14 @@ import static io.trino.sql.planner.assertions.PlanMatchPattern.node; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN_CORRELATION; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN_COVARIANCE; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN_REGRESSION; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN_STDDEV; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN_VARIANCE; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_CANCELLATION; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_CREATE_TABLE; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_INSERT; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_JOIN_PUSHDOWN; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_JOIN_PUSHDOWN_WITH_DISTINCT_FROM; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_JOIN_PUSHDOWN_WITH_FULL_JOIN; @@ -82,6 +84,7 @@ import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_TOPN_PUSHDOWN; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_TOPN_PUSHDOWN_WITH_VARCHAR; import static io.trino.testing.assertions.Assert.assertEventually; +import static io.trino.testing.sql.TestTable.randomTableSuffix; import static java.lang.String.format; import static java.util.Locale.ENGLISH; import static java.util.concurrent.Executors.newCachedThreadPool; @@ -231,6 +234,8 @@ public void testCaseSensitiveAggregationPushdown() } boolean supportsPushdownWithVarcharInequality = hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY); + boolean supportsCountDistinctPushdown = hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT); + PlanMatchPattern aggregationOverTableScan = node(AggregationNode.class, node(TableScanNode.class)); PlanMatchPattern groupingAggregationOverTableScan = node(AggregationNode.class, node(ProjectNode.class, node(TableScanNode.class))); try (TestTable table = new TestTable( @@ -297,11 +302,19 @@ public void testCaseSensitiveAggregationPushdown() .skippingTypesCheck() .matches("VALUES BIGINT '4'"); - // TODO: multiple count(DISTINCT expr) cannot be pushed down until https://github.com/trinodb/trino/pull/8562 gets merged - assertThat(query("SELECT count(DISTINCT a_string), count(DISTINCT a_bigint) FROM " + table.getName())) - .isNotFullyPushedDown(MarkDistinctNode.class, ExchangeNode.class, ExchangeNode.class, ProjectNode.class); - assertThat(query("SELECT count(DISTINCT a_char), count(DISTINCT a_bigint) FROM " + table.getName())) - .isNotFullyPushedDown(MarkDistinctNode.class, ExchangeNode.class, ExchangeNode.class, ProjectNode.class); + assertConditionallyPushedDown(getSession(), + "SELECT count(DISTINCT a_string), count(DISTINCT a_bigint) FROM " + table.getName(), + supportsPushdownWithVarcharInequality && supportsCountDistinctPushdown, + node(ExchangeNode.class, node(AggregationNode.class, anyTree(node(TableScanNode.class))))) + .skippingTypesCheck() + .matches("VALUES (BIGINT '4', BIGINT '4')"); + + assertConditionallyPushedDown(getSession(), + "SELECT count(DISTINCT a_char), count(DISTINCT a_bigint) FROM " + table.getName(), + supportsPushdownWithVarcharInequality && supportsCountDistinctPushdown, + node(ExchangeNode.class, node(AggregationNode.class, anyTree(node(TableScanNode.class))))) + .skippingTypesCheck() + .matches("VALUES (BIGINT '4', BIGINT '4')"); } } @@ -354,11 +367,17 @@ public void testDistinctAggregationPushdown() hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY), node(AggregationNode.class, node(ProjectNode.class, node(TableScanNode.class)))); // two distinct aggregations - assertThat(query(withMarkDistinct, "SELECT count(DISTINCT regionkey), count(DISTINCT nationkey) FROM nation")) - .isNotFullyPushedDown(MarkDistinctNode.class, ExchangeNode.class, ExchangeNode.class, ProjectNode.class); + assertConditionallyPushedDown( + withMarkDistinct, + "SELECT count(DISTINCT regionkey), sum(nationkey) FROM nation", + hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT), + node(MarkDistinctNode.class, node(ExchangeNode.class, node(ExchangeNode.class, node(ProjectNode.class, node(TableScanNode.class)))))); // distinct aggregation and a non-distinct aggregation - assertThat(query(withMarkDistinct, "SELECT count(DISTINCT regionkey), sum(nationkey) FROM nation")) - .isNotFullyPushedDown(MarkDistinctNode.class, ExchangeNode.class, ExchangeNode.class, ProjectNode.class); + assertConditionallyPushedDown( + withMarkDistinct, + "SELECT count(DISTINCT regionkey), count(DISTINCT nationkey) FROM nation", + hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT), + node(MarkDistinctNode.class, node(ExchangeNode.class, node(ExchangeNode.class, node(ProjectNode.class, node(TableScanNode.class)))))); Session withoutMarkDistinct = Session.builder(getSession()) .setSystemProperty(USE_MARK_DISTINCT, "false") @@ -374,11 +393,17 @@ public void testDistinctAggregationPushdown() hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY), node(AggregationNode.class, node(ProjectNode.class, node(TableScanNode.class)))); // two distinct aggregations - assertThat(query(withoutMarkDistinct, "SELECT count(DISTINCT regionkey), count(DISTINCT nationkey) FROM nation")) - .isNotFullyPushedDown(AggregationNode.class, ExchangeNode.class, ExchangeNode.class); + assertConditionallyPushedDown( + withoutMarkDistinct, + "SELECT count(DISTINCT regionkey), count(DISTINCT nationkey) FROM nation", + hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT), + node(AggregationNode.class, node(ExchangeNode.class, node(ExchangeNode.class, node(TableScanNode.class))))); // distinct aggregation and a non-distinct aggregation - assertThat(query(withoutMarkDistinct, "SELECT count(DISTINCT regionkey), sum(nationkey) FROM nation")) - .isNotFullyPushedDown(AggregationNode.class, ExchangeNode.class, ExchangeNode.class); + assertConditionallyPushedDown( + withoutMarkDistinct, + "SELECT count(DISTINCT regionkey), sum(nationkey) FROM nation", + hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT), + node(AggregationNode.class, node(ExchangeNode.class, node(ExchangeNode.class, node(TableScanNode.class))))); } @Test @@ -424,6 +449,62 @@ public void testNumericAggregationPushdown() } } + @Test + public void testCountDistinctWithStringTypes() + { + if (!(hasBehavior(SUPPORTS_CREATE_TABLE) && hasBehavior(SUPPORTS_INSERT))) { + throw new SkipException("Unable to CREATE TABLE to test count distinct"); + } + + List rows = Stream.of("a", "b", "A", "B", " a ", "a", "b", " b ", "ą") + .map(value -> format("'%1$s', '%1$s'", value)) + .collect(toImmutableList()); + String tableName = "distinct_strings" + randomTableSuffix(); + + try (TestTable testTable = new TestTable(getQueryRunner()::execute, tableName, "(t_char CHAR(5), t_varchar VARCHAR(5))", rows)) { + if (!(hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN) && hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY))) { + // disabling hash generation to prevent extra projections in the plan which make it hard to write matchers for isNotFullyPushedDown + Session optimizeHashGenerationDisabled = Session.builder(getSession()) + .setSystemProperty("optimize_hash_generation", "false") + .build(); + + // It is not captured in the `isNotFullyPushedDown` calls (can't do that) but depending on the connector in use some aggregations + // still can be pushed down to connector. + // If `SUPPORTS_AGGREGATION_PUSHDOWN == false` but `SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY == true` the DISTINCT part of aggregation + // will still be pushed down to connector as `GROUP BY`. Only the `count` part will remain on the Trino side. + // If `SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY == false` both parts of aggregation will be executed on Trino side. + + assertThat(query(optimizeHashGenerationDisabled, "SELECT count(DISTINCT t_varchar) FROM " + testTable.getName())) + .matches("VALUES BIGINT '7'") + .isNotFullyPushedDown(AggregationNode.class); + assertThat(query(optimizeHashGenerationDisabled, "SELECT count(DISTINCT t_char) FROM " + testTable.getName())) + .matches("VALUES BIGINT '7'") + .isNotFullyPushedDown(AggregationNode.class); + + assertThat(query("SELECT count(DISTINCT t_char), count(DISTINCT t_varchar) FROM " + testTable.getName())) + .matches("VALUES (BIGINT '7', BIGINT '7')") + .isNotFullyPushedDown(MarkDistinctNode.class, ExchangeNode.class, ExchangeNode.class, ProjectNode.class); + } + else { + // Single count(DISTINCT ...) can be pushed even down even if SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT == false as GROUP BY + assertThat(query("SELECT count(DISTINCT t_varchar) FROM " + testTable.getName())) + .matches("VALUES BIGINT '7'") + .isFullyPushedDown(); + + // Single count(DISTINCT ...) can be pushed down even if SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT == false as GROUP BY + assertThat(query("SELECT count(DISTINCT t_char) FROM " + testTable.getName())) + .matches("VALUES BIGINT '7'") + .isFullyPushedDown(); + + assertConditionallyPushedDown( + getSession(), + "SELECT count(DISTINCT t_char), count(DISTINCT t_varchar) FROM " + testTable.getName(), + hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT), + node(MarkDistinctNode.class, node(ExchangeNode.class, node(ExchangeNode.class, node(ProjectNode.class, node(TableScanNode.class)))))); + } + } + } + /** * Creates a table with columns {@code short_decimal decimal(9, 3), long_decimal decimal(30, 10), t_double double, a_bigint bigint} populated * with the provided rows. diff --git a/plugin/trino-phoenix/src/test/java/io/trino/plugin/phoenix/TestPhoenixConnectorTest.java b/plugin/trino-phoenix/src/test/java/io/trino/plugin/phoenix/TestPhoenixConnectorTest.java index 2e05cd4ed8b8e..eef6f10b0a5f4 100644 --- a/plugin/trino-phoenix/src/test/java/io/trino/plugin/phoenix/TestPhoenixConnectorTest.java +++ b/plugin/trino-phoenix/src/test/java/io/trino/plugin/phoenix/TestPhoenixConnectorTest.java @@ -14,6 +14,7 @@ package io.trino.plugin.phoenix; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Streams; import io.trino.Session; import io.trino.plugin.jdbc.BaseJdbcConnectorTest; import io.trino.plugin.jdbc.UnsupportedTypeHandling; @@ -31,10 +32,13 @@ import java.sql.SQLException; import java.sql.Statement; import java.util.List; +import java.util.stream.Stream; +import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.plugin.jdbc.TypeHandlingJdbcSessionProperties.UNSUPPORTED_TYPE_HANDLING; import static io.trino.plugin.jdbc.UnsupportedTypeHandling.CONVERT_TO_VARCHAR; import static io.trino.plugin.phoenix.PhoenixQueryRunner.createPhoenixQueryRunner; +import static io.trino.testing.sql.TestTable.randomTableSuffix; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertTrue; @@ -248,6 +252,23 @@ public void testVarcharCharComparison() } } + // Overridden because Phoenix requires a ROWID column + @Override + public void testCountDistinctWithStringTypes() + { + assertThatThrownBy(super::testCountDistinctWithStringTypes).hasStackTraceContaining("Illegal data. CHAR types may only contain single byte characters"); + // Skipping the ą test case because it is not supported + List rows = Streams.mapWithIndex(Stream.of("a", "b", "A", "B", " a ", "a", "b", " b "), (value, idx) -> String.format("%d, '%2$s', '%2$s'", idx, value)) + .collect(toImmutableList()); + String tableName = "count_distinct_strings" + randomTableSuffix(); + + try (TestTable testTable = new TestTable(getQueryRunner()::execute, tableName, "(id int, t_char CHAR(5), t_varchar VARCHAR(5)) WITH (ROWKEYS='id')", rows)) { + assertQuery("SELECT count(DISTINCT t_varchar) FROM " + testTable.getName(), "VALUES 6"); + assertQuery("SELECT count(DISTINCT t_char) FROM " + testTable.getName(), "VALUES 6"); + assertQuery("SELECT count(DISTINCT t_char), count(DISTINCT t_varchar) FROM " + testTable.getName(), "VALUES (6, 6)"); + } + } + @Test public void testSchemaOperations() { diff --git a/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java b/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java index e5b801bee406f..253fc12d0b2fb 100644 --- a/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java +++ b/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Streams; import io.trino.Session; import io.trino.plugin.jdbc.BaseJdbcConnectorTest; import io.trino.plugin.jdbc.UnsupportedTypeHandling; @@ -32,6 +33,7 @@ import java.sql.SQLException; import java.sql.Statement; import java.util.List; +import java.util.stream.Stream; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.plugin.jdbc.TypeHandlingJdbcSessionProperties.UNSUPPORTED_TYPE_HANDLING; @@ -52,6 +54,7 @@ import static io.trino.sql.tree.SortItem.NullOrdering.LAST; import static io.trino.sql.tree.SortItem.Ordering.ASCENDING; import static io.trino.sql.tree.SortItem.Ordering.DESCENDING; +import static io.trino.testing.sql.TestTable.randomTableSuffix; import static java.lang.String.format; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -266,6 +269,23 @@ public void testVarcharCharComparison() } } + // Overridden because Phoenix requires a ROWID column + @Override + public void testCountDistinctWithStringTypes() + { + assertThatThrownBy(super::testCountDistinctWithStringTypes).hasStackTraceContaining("Illegal data. CHAR types may only contain single byte characters"); + // Skipping the ą test case because it is not supported + List rows = Streams.mapWithIndex(Stream.of("a", "b", "A", "B", " a ", "a", "b", " b "), (value, idx) -> String.format("%d, '%2$s', '%2$s'", idx, value)) + .collect(toImmutableList()); + String tableName = "count_distinct_strings" + randomTableSuffix(); + + try (TestTable testTable = new TestTable(getQueryRunner()::execute, tableName, "(id int, t_char CHAR(5), t_varchar VARCHAR(5)) WITH (ROWKEYS='id')", rows)) { + assertQuery("SELECT count(DISTINCT t_varchar) FROM " + testTable.getName(), "VALUES 6"); + assertQuery("SELECT count(DISTINCT t_char) FROM " + testTable.getName(), "VALUES 6"); + assertQuery("SELECT count(DISTINCT t_char), count(DISTINCT t_varchar) FROM " + testTable.getName(), "VALUES (6, 6)"); + } + } + @Test public void testSchemaOperations() { diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index 40011ae236bdd..349fb2512aedc 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -50,6 +50,7 @@ import io.trino.plugin.jdbc.expression.ImplementCorr; import io.trino.plugin.jdbc.expression.ImplementCount; import io.trino.plugin.jdbc.expression.ImplementCountAll; +import io.trino.plugin.jdbc.expression.ImplementCountDistinct; import io.trino.plugin.jdbc.expression.ImplementCovariancePop; import io.trino.plugin.jdbc.expression.ImplementCovarianceSamp; import io.trino.plugin.jdbc.expression.ImplementMinMax; @@ -282,8 +283,9 @@ public PostgreSqlClient( this::quoted, ImmutableSet.>builder() .add(new ImplementCountAll(bigintTypeHandle)) - .add(new ImplementCount(bigintTypeHandle)) .add(new ImplementMinMax(false)) + .add(new ImplementCount(bigintTypeHandle)) + .add(new ImplementCountDistinct(bigintTypeHandle, false)) .add(new ImplementSum(PostgreSqlClient::toTypeHandle)) .add(new ImplementAvgFloatingPoint()) .add(new ImplementAvgDecimal()) diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java index 6c56997a07bdd..aa41a46fc7a52 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java @@ -97,7 +97,7 @@ public void testImplementCount() testImplementAggregation( new AggregateFunction("count", BIGINT, List.of(bigintVariable), List.of(), true, Optional.empty()), Map.of(bigintVariable.getName(), BIGINT_COLUMN), - Optional.empty()); + Optional.of("count(DISTINCT \"c_bigint\")")); // count() FILTER (WHERE ...) diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java index 389a55d2de9a7..dcd682d6e6966 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java @@ -93,6 +93,7 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) case SUPPORTS_AGGREGATION_PUSHDOWN_COVARIANCE: case SUPPORTS_AGGREGATION_PUSHDOWN_CORRELATION: case SUPPORTS_AGGREGATION_PUSHDOWN_REGRESSION: + case SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT: return true; case SUPPORTS_JOIN_PUSHDOWN: diff --git a/testing/trino-testing/src/main/java/io/trino/testing/TestingConnectorBehavior.java b/testing/trino-testing/src/main/java/io/trino/testing/TestingConnectorBehavior.java index 5a5ef3c719416..cff3c8c241e83 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/TestingConnectorBehavior.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/TestingConnectorBehavior.java @@ -35,6 +35,7 @@ public enum TestingConnectorBehavior SUPPORTS_AGGREGATION_PUSHDOWN_COVARIANCE(false), SUPPORTS_AGGREGATION_PUSHDOWN_CORRELATION(false), SUPPORTS_AGGREGATION_PUSHDOWN_REGRESSION(false), + SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT(false), SUPPORTS_JOIN_PUSHDOWN( // Currently no connector supports Join pushdown by default. JDBC connectors may support Join pushdown and BaseJdbcConnectorTest