Implement count(distinct) pushdown for PostgreSQL
alexjo2144 authored and losipiuk committed Sep 2, 2021
1 parent 25965e9 commit d55bb1c
Showing 9 changed files with 233 additions and 17 deletions.
@@ -485,7 +485,6 @@ public PlanOptimizers(
new RemoveRedundantExists(),
new ImplementFilteredAggregations(metadata),
new SingleDistinctAggregationToGroupBy(),
new MultipleDistinctAggregationToMarkDistinct(),
new MergeLimitWithDistinct(),
new PruneCountAggregationOverScalar(metadata),
new PruneOrderByInAggregation(metadata),
@@ -675,7 +674,8 @@ public PlanOptimizers(
new RemoveEmptyExceptBranches(),
new RemoveRedundantIdentityProjections(),
new PushAggregationThroughOuterJoin(),
new ReplaceRedundantJoinWithSource())), // Run this after PredicatePushDown optimizer as it inlines filter constants
new ReplaceRedundantJoinWithSource(), // Run this after PredicatePushDown optimizer as it inlines filter constants
new MultipleDistinctAggregationToMarkDistinct())), // Run this after aggregation pushdown so that multiple distinct aggregations can be pushed into a connector
inlineProjections,
simplifyOptimizer, // Re-run the SimplifyExpressions to simplify any recomposed expressions from other optimizations
pushProjectionIntoTableScanOptimizer,
@@ -0,0 +1,90 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc.expression;

import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.AggregateFunctionRule;
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.CharType;
import io.trino.spi.type.VarcharType;

import java.util.Optional;

import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.distinct;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.hasFilter;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variable;
import static io.trino.spi.type.BigintType.BIGINT;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

/**
* Implements {@code count(DISTINCT x)}.
*/
public class ImplementCountDistinct
implements AggregateFunctionRule
{
private static final Capture<Variable> INPUT = newCapture();

private final JdbcTypeHandle bigintTypeHandle;
private final boolean isRemoteCollationSensitive;

/**
* @param bigintTypeHandle A {@link JdbcTypeHandle} that will be mapped to {@link BigintType} by {@link JdbcClient#toColumnMapping}.
*/
public ImplementCountDistinct(JdbcTypeHandle bigintTypeHandle, boolean isRemoteCollationSensitive)
{
this.bigintTypeHandle = requireNonNull(bigintTypeHandle, "bigintTypeHandle is null");
this.isRemoteCollationSensitive = isRemoteCollationSensitive;
}

@Override
public Pattern<AggregateFunction> getPattern()
{
return Pattern.typeOf(AggregateFunction.class)
.with(distinct().equalTo(true))
.with(hasFilter().equalTo(false))
.with(functionName().equalTo("count"))
.with(singleInput().matching(variable().capturedAs(INPUT)));
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
{
Variable input = captures.get(INPUT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
verify(aggregateFunction.getOutputType() == BIGINT);

boolean isCaseSensitiveType = columnHandle.getColumnType() instanceof CharType || columnHandle.getColumnType() instanceof VarcharType;
if (aggregateFunction.isDistinct() && !isRemoteCollationSensitive && isCaseSensitiveType) {
// Remote database is case insensitive or compares values differently from Trino
return Optional.empty();
}

return Optional.of(new JdbcExpression(
format("count(DISTINCT %s)", context.getIdentifierQuote().apply(columnHandle.getColumnName())),
bigintTypeHandle));
}
}
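For context, a rough sketch (not part of this commit) of how a JDBC connector such as PostgreSQL could register the new rule in its client, next to the existing count rules. The surrounding class name, the quoted() helper, and the exact AggregateFunctionRewriter signatures are assumptions based on how the JDBC connectors wire up aggregate pushdown rules around this time; they may differ from the real PostgreSqlClient.

// Sketch only: registering ImplementCountDistinct alongside the existing count rules.
import com.google.common.collect.ImmutableSet;
import io.trino.plugin.base.expression.AggregateFunctionRewriter;
import io.trino.plugin.base.expression.AggregateFunctionRule;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.expression.ImplementCount;
import io.trino.plugin.jdbc.expression.ImplementCountAll;
import io.trino.plugin.jdbc.expression.ImplementCountDistinct;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorSession;

import java.util.Map;
import java.util.Optional;

public class ExampleCountDistinctRegistration
{
    private final AggregateFunctionRewriter aggregateFunctionRewriter;

    public ExampleCountDistinctRegistration(JdbcTypeHandle bigintTypeHandle)
    {
        this.aggregateFunctionRewriter = new AggregateFunctionRewriter(
                this::quoted,
                ImmutableSet.<AggregateFunctionRule>builder()
                        .add(new ImplementCountAll(bigintTypeHandle))
                        .add(new ImplementCount(bigintTypeHandle))
                        // PostgreSQL compares char/varchar values case sensitively by default, so the remote
                        // side is declared collation sensitive; a case-insensitive remote database would pass
                        // false here, and the rule above would then keep count(DISTINCT <string column>) on
                        // the Trino side.
                        .add(new ImplementCountDistinct(bigintTypeHandle, true))
                        .build());
    }

    public Optional<JdbcExpression> implementAggregation(ConnectorSession session, AggregateFunction aggregate, Map<String, ColumnHandle> assignments)
    {
        return aggregateFunctionRewriter.rewrite(session, aggregate, assignments);
    }

    private String quoted(String name)
    {
        return "\"" + name.replace("\"", "\"\"") + "\"";
    }
}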
@@ -64,12 +64,14 @@
import static io.trino.sql.planner.assertions.PlanMatchPattern.node;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN_CORRELATION;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN_COVARIANCE;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN_REGRESSION;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN_STDDEV;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN_VARIANCE;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_CANCELLATION;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_CREATE_TABLE;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_INSERT;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_JOIN_PUSHDOWN;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_JOIN_PUSHDOWN_WITH_DISTINCT_FROM;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_JOIN_PUSHDOWN_WITH_FULL_JOIN;
@@ -82,6 +84,7 @@
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_TOPN_PUSHDOWN;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_TOPN_PUSHDOWN_WITH_VARCHAR;
import static io.trino.testing.assertions.Assert.assertEventually;
import static io.trino.testing.sql.TestTable.randomTableSuffix;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;
import static java.util.concurrent.Executors.newCachedThreadPool;
@@ -231,6 +234,8 @@ public void testCaseSensitiveAggregationPushdown()
}

boolean supportsPushdownWithVarcharInequality = hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY);
boolean supportsCountDistinctPushdown = hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT);

PlanMatchPattern aggregationOverTableScan = node(AggregationNode.class, node(TableScanNode.class));
PlanMatchPattern groupingAggregationOverTableScan = node(AggregationNode.class, node(ProjectNode.class, node(TableScanNode.class)));
try (TestTable table = new TestTable(
@@ -297,11 +302,19 @@ public void testCaseSensitiveAggregationPushdown()
.skippingTypesCheck()
.matches("VALUES BIGINT '4'");

// TODO: multiple count(DISTINCT expr) cannot be pushed down until https://github.com/trinodb/trino/pull/8562 gets merged
assertThat(query("SELECT count(DISTINCT a_string), count(DISTINCT a_bigint) FROM " + table.getName()))
.isNotFullyPushedDown(MarkDistinctNode.class, ExchangeNode.class, ExchangeNode.class, ProjectNode.class);
assertThat(query("SELECT count(DISTINCT a_char), count(DISTINCT a_bigint) FROM " + table.getName()))
.isNotFullyPushedDown(MarkDistinctNode.class, ExchangeNode.class, ExchangeNode.class, ProjectNode.class);
assertConditionallyPushedDown(getSession(),
"SELECT count(DISTINCT a_string), count(DISTINCT a_bigint) FROM " + table.getName(),
supportsPushdownWithVarcharInequality && supportsCountDistinctPushdown,
node(ExchangeNode.class, node(AggregationNode.class, anyTree(node(TableScanNode.class)))))
.skippingTypesCheck()
.matches("VALUES (BIGINT '4', BIGINT '4')");

assertConditionallyPushedDown(getSession(),
"SELECT count(DISTINCT a_char), count(DISTINCT a_bigint) FROM " + table.getName(),
supportsPushdownWithVarcharInequality && supportsCountDistinctPushdown,
node(ExchangeNode.class, node(AggregationNode.class, anyTree(node(TableScanNode.class)))))
.skippingTypesCheck()
.matches("VALUES (BIGINT '4', BIGINT '4')");
}
}

@@ -354,11 +367,17 @@ public void testDistinctAggregationPushdown()
hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY),
node(AggregationNode.class, node(ProjectNode.class, node(TableScanNode.class))));
// two distinct aggregations
assertThat(query(withMarkDistinct, "SELECT count(DISTINCT regionkey), count(DISTINCT nationkey) FROM nation"))
.isNotFullyPushedDown(MarkDistinctNode.class, ExchangeNode.class, ExchangeNode.class, ProjectNode.class);
assertConditionallyPushedDown(
withMarkDistinct,
"SELECT count(DISTINCT regionkey), sum(nationkey) FROM nation",
hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT),
node(MarkDistinctNode.class, node(ExchangeNode.class, node(ExchangeNode.class, node(ProjectNode.class, node(TableScanNode.class))))));
// distinct aggregation and a non-distinct aggregation
assertThat(query(withMarkDistinct, "SELECT count(DISTINCT regionkey), sum(nationkey) FROM nation"))
.isNotFullyPushedDown(MarkDistinctNode.class, ExchangeNode.class, ExchangeNode.class, ProjectNode.class);
assertConditionallyPushedDown(
withMarkDistinct,
"SELECT count(DISTINCT regionkey), count(DISTINCT nationkey) FROM nation",
hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT),
node(MarkDistinctNode.class, node(ExchangeNode.class, node(ExchangeNode.class, node(ProjectNode.class, node(TableScanNode.class))))));

Session withoutMarkDistinct = Session.builder(getSession())
.setSystemProperty(USE_MARK_DISTINCT, "false")
Expand All @@ -374,11 +393,17 @@ public void testDistinctAggregationPushdown()
hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY),
node(AggregationNode.class, node(ProjectNode.class, node(TableScanNode.class))));
// two distinct aggregations
assertThat(query(withoutMarkDistinct, "SELECT count(DISTINCT regionkey), count(DISTINCT nationkey) FROM nation"))
.isNotFullyPushedDown(AggregationNode.class, ExchangeNode.class, ExchangeNode.class);
assertConditionallyPushedDown(
withoutMarkDistinct,
"SELECT count(DISTINCT regionkey), count(DISTINCT nationkey) FROM nation",
hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT),
node(AggregationNode.class, node(ExchangeNode.class, node(ExchangeNode.class, node(TableScanNode.class)))));
// distinct aggregation and a non-distinct aggregation
assertThat(query(withoutMarkDistinct, "SELECT count(DISTINCT regionkey), sum(nationkey) FROM nation"))
.isNotFullyPushedDown(AggregationNode.class, ExchangeNode.class, ExchangeNode.class);
assertConditionallyPushedDown(
withoutMarkDistinct,
"SELECT count(DISTINCT regionkey), sum(nationkey) FROM nation",
hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT),
node(AggregationNode.class, node(ExchangeNode.class, node(ExchangeNode.class, node(TableScanNode.class)))));
}

@Test
@@ -424,6 +449,62 @@ public void testNumericAggregationPushdown()
}
}

@Test
public void testCountDistinctWithStringTypes()
{
if (!(hasBehavior(SUPPORTS_CREATE_TABLE) && hasBehavior(SUPPORTS_INSERT))) {
throw new SkipException("Unable to CREATE TABLE to test count distinct");
}

List<String> rows = Stream.of("a", "b", "A", "B", " a ", "a", "b", " b ", "ą")
.map(value -> format("'%1$s', '%1$s'", value))
.collect(toImmutableList());
String tableName = "distinct_strings" + randomTableSuffix();

try (TestTable testTable = new TestTable(getQueryRunner()::execute, tableName, "(t_char CHAR(5), t_varchar VARCHAR(5))", rows)) {
if (!(hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN) && hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY))) {
// disabling hash generation to prevent extra projections in the plan which make it hard to write matchers for isNotFullyPushedDown
Session optimizeHashGenerationDisabled = Session.builder(getSession())
.setSystemProperty("optimize_hash_generation", "false")
.build();

// This is not captured by the `isNotFullyPushedDown` calls below (they cannot express it), but depending on the connector in use
// some of the aggregation can still be pushed down to the connector.
// If `SUPPORTS_AGGREGATION_PUSHDOWN == false` but `SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY == true`, the DISTINCT part
// of the aggregation is still pushed down to the connector as `GROUP BY`; only the `count` part remains on the Trino side.
// If `SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY == false`, both parts of the aggregation are executed on the Trino side.

assertThat(query(optimizeHashGenerationDisabled, "SELECT count(DISTINCT t_varchar) FROM " + testTable.getName()))
.matches("VALUES BIGINT '7'")
.isNotFullyPushedDown(AggregationNode.class);
assertThat(query(optimizeHashGenerationDisabled, "SELECT count(DISTINCT t_char) FROM " + testTable.getName()))
.matches("VALUES BIGINT '7'")
.isNotFullyPushedDown(AggregationNode.class);

assertThat(query("SELECT count(DISTINCT t_char), count(DISTINCT t_varchar) FROM " + testTable.getName()))
.matches("VALUES (BIGINT '7', BIGINT '7')")
.isNotFullyPushedDown(MarkDistinctNode.class, ExchangeNode.class, ExchangeNode.class, ProjectNode.class);
}
else {
// Single count(DISTINCT ...) can be pushed down even if SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT == false as GROUP BY
assertThat(query("SELECT count(DISTINCT t_varchar) FROM " + testTable.getName()))
.matches("VALUES BIGINT '7'")
.isFullyPushedDown();

// Single count(DISTINCT ...) can be pushed down even if SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT == false as GROUP BY
assertThat(query("SELECT count(DISTINCT t_char) FROM " + testTable.getName()))
.matches("VALUES BIGINT '7'")
.isFullyPushedDown();

assertConditionallyPushedDown(
getSession(),
"SELECT count(DISTINCT t_char), count(DISTINCT t_varchar) FROM " + testTable.getName(),
hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT),
node(MarkDistinctNode.class, node(ExchangeNode.class, node(ExchangeNode.class, node(ProjectNode.class, node(TableScanNode.class))))));
}
}
}

/**
* Creates a table with columns {@code short_decimal decimal(9, 3), long_decimal decimal(30, 10), t_double double, a_bigint bigint} populated
* with the provided rows.
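The assertions above rely on assertConditionallyPushedDown, a helper that already exists in the test base class and therefore does not appear in this diff. As a rough sketch only, inferred from how the calls are chained with skippingTypesCheck() and matches(), its shape is approximately the following; the real signature and body may differ.

// Sketch of the pre-existing helper (not part of this commit), assumed to live in the same
// test base class as the tests above and inferred from how it is used: it either expects
// complete pushdown or verifies which plan nodes remain above the table scan.
protected QueryAssert assertConditionallyPushedDown(
        Session session,
        String query,
        boolean condition,
        PlanMatchPattern otherwiseExpected)
{
    QueryAssert queryAssert = assertThat(query(session, query));
    if (condition) {
        // the connector declares support, so the whole aggregation must be pushed into the table scan
        return queryAssert.isFullyPushedDown();
    }
    // otherwise pushdown stops, and the retained plan nodes must match the given pattern
    return queryAssert.isNotFullyPushedDown(otherwiseExpected);
}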
@@ -14,6 +14,7 @@
package io.trino.plugin.phoenix;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Streams;
import io.trino.Session;
import io.trino.plugin.jdbc.BaseJdbcConnectorTest;
import io.trino.plugin.jdbc.UnsupportedTypeHandling;
@@ -31,10 +32,13 @@
import java.sql.SQLException;
import java.sql.Statement;
import java.util.List;
import java.util.stream.Stream;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.plugin.jdbc.TypeHandlingJdbcSessionProperties.UNSUPPORTED_TYPE_HANDLING;
import static io.trino.plugin.jdbc.UnsupportedTypeHandling.CONVERT_TO_VARCHAR;
import static io.trino.plugin.phoenix.PhoenixQueryRunner.createPhoenixQueryRunner;
import static io.trino.testing.sql.TestTable.randomTableSuffix;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.testng.Assert.assertTrue;
@@ -248,6 +252,23 @@ public void testVarcharCharComparison()
}
}

// Overridden because Phoenix requires a ROWID column
@Override
public void testCountDistinctWithStringTypes()
{
assertThatThrownBy(super::testCountDistinctWithStringTypes).hasStackTraceContaining("Illegal data. CHAR types may only contain single byte characters");
// Skipping the ą test case because it is not supported
List<String> rows = Streams.mapWithIndex(Stream.of("a", "b", "A", "B", " a ", "a", "b", " b "), (value, idx) -> String.format("%d, '%2$s', '%2$s'", idx, value))
.collect(toImmutableList());
String tableName = "count_distinct_strings" + randomTableSuffix();

try (TestTable testTable = new TestTable(getQueryRunner()::execute, tableName, "(id int, t_char CHAR(5), t_varchar VARCHAR(5)) WITH (ROWKEYS='id')", rows)) {
assertQuery("SELECT count(DISTINCT t_varchar) FROM " + testTable.getName(), "VALUES 6");
assertQuery("SELECT count(DISTINCT t_char) FROM " + testTable.getName(), "VALUES 6");
assertQuery("SELECT count(DISTINCT t_char), count(DISTINCT t_varchar) FROM " + testTable.getName(), "VALUES (6, 6)");
}
}

@Test
public void testSchemaOperations()
{